diff --git a/accelerate/commands/__init__.py b/accelerate/commands/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9cbe26c257b515f657c05e1996d517e69613972 --- /dev/null +++ b/accelerate/commands/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/accelerate/commands/accelerate_cli.py b/accelerate/commands/accelerate_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..b878c8debd874e1418b946775b11568c7487ad72 --- /dev/null +++ b/accelerate/commands/accelerate_cli.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from accelerate.commands.config import get_config_parser +from accelerate.commands.env import env_command_parser +from accelerate.commands.estimate import estimate_command_parser +from accelerate.commands.launch import launch_command_parser +from accelerate.commands.merge import merge_command_parser +from accelerate.commands.test import test_command_parser +from accelerate.commands.to_fsdp2 import to_fsdp2_command_parser +from accelerate.commands.tpu import tpu_command_parser +from accelerate.commands.utils import CustomArgumentParser + + +def main(): + parser = CustomArgumentParser("Accelerate CLI tool", usage="accelerate []", allow_abbrev=False) + subparsers = parser.add_subparsers(help="accelerate command helpers") + + # Register commands + get_config_parser(subparsers=subparsers) + estimate_command_parser(subparsers=subparsers) + env_command_parser(subparsers=subparsers) + launch_command_parser(subparsers=subparsers) + merge_command_parser(subparsers=subparsers) + tpu_command_parser(subparsers=subparsers) + test_command_parser(subparsers=subparsers) + to_fsdp2_command_parser(subparsers=subparsers) + + # Let's go + args = parser.parse_args() + + if not hasattr(args, "func"): + parser.print_help() + exit(1) + + # Run + args.func(args) + + +if __name__ == "__main__": + main() diff --git a/accelerate/commands/config/__init__.py b/accelerate/commands/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..649a15888cccd070b3d4ca9a600457c6ad59d4d3 --- /dev/null +++ b/accelerate/commands/config/__init__.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +from .config import config_command_parser +from .config_args import default_config_file, load_config_from_file # noqa: F401 +from .default import default_command_parser +from .update import update_command_parser + + +def get_config_parser(subparsers=None): + parent_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False) + # The main config parser + config_parser = config_command_parser(subparsers) + # The subparser to add commands to + subcommands = config_parser.add_subparsers(title="subcommands", dest="subcommand") + + # Then add other parsers with the parent parser + default_command_parser(subcommands, parents=[parent_parser]) + update_command_parser(subcommands, parents=[parent_parser]) + + return config_parser + + +def main(): + config_parser = get_config_parser() + args = config_parser.parse_args() + + if not hasattr(args, "func"): + config_parser.print_help() + exit(1) + + # Run + args.func(args) + + +if __name__ == "__main__": + main() diff --git a/accelerate/commands/config/cluster.py b/accelerate/commands/config/cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..fce80efdcbbb6e36572716d38b3d74491ebc918f --- /dev/null +++ b/accelerate/commands/config/cluster.py @@ -0,0 +1,939 @@ +#!/usr/bin/env python + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from ...utils import ( + ComputeEnvironment, + DistributedType, + is_deepspeed_available, + is_fp8_available, + is_hpu_available, + is_mlu_available, + is_mps_available, + is_msamp_available, + is_musa_available, + is_neuron_available, + is_npu_available, + is_sdaa_available, + is_torchao_available, + is_transformer_engine_available, + is_transformers_available, + is_xpu_available, +) +from ...utils.constants import ( + DEEPSPEED_MULTINODE_LAUNCHERS, + FSDP2_STATE_DICT_TYPE, + FSDP_AUTO_WRAP_POLICY, + FSDP_BACKWARD_PREFETCH, + FSDP_SHARDING_STRATEGY, + FSDP_STATE_DICT_TYPE, + TORCH_DYNAMO_MODES, +) +from .config_args import ClusterConfig +from .config_utils import ( + DYNAMO_BACKENDS, + _ask_field, + _ask_options, + _convert_distributed_mode, + _convert_dynamo_backend, + _convert_fp8_backend, + _convert_mixed_precision, + _convert_yes_no_to_bool, +) + + +def get_cluster_input(): + distributed_type = _ask_options( + "Which type of machine are you using?", + [ + "No distributed training", + "multi-CPU", + "multi-XPU", + "multi-HPU", + "multi-GPU", + "multi-NPU", + "multi-MLU", + "multi-SDAA", + "multi-MUSA", + "multi-NEURON", + "TPU", + ], + _convert_distributed_mode, + ) + + machine_rank = 0 + num_machines = 1 + num_processes = 1 + gpu_ids = None + main_process_ip = None + main_process_port = None + rdzv_backend = "static" + same_network = True + debug = False + + if distributed_type in [ + DistributedType.MULTI_GPU, + DistributedType.MULTI_MLU, + DistributedType.MULTI_SDAA, + DistributedType.MULTI_MUSA, + DistributedType.MULTI_NPU, + DistributedType.MULTI_XPU, + DistributedType.MULTI_CPU, + DistributedType.MULTI_HPU, + DistributedType.MULTI_NEURON, + ]: + num_machines = _ask_field( + "How many different machines will you use (use more than 1 for multi-node training)? [1]: ", + int, + default=1, + ) + if num_machines > 1: + machine_rank = _ask_options( + "What is the rank of this machine?", + list(range(num_machines)), + int, + ) + main_process_ip = _ask_field( + "What is the IP address of the machine that will host the main process? ", + ) + main_process_port = _ask_field( + "What is the port you will use to communicate with the main process? ", + int, + ) + same_network = _ask_field( + "Are all the machines on the same local network? Answer `no` if nodes are on the cloud and/or on different network hosts [YES/no]: ", + _convert_yes_no_to_bool, + default=True, + error_message="Please enter yes or no.", + ) + if not same_network: + rdzv_backend = _ask_field( + "What rendezvous backend will you use? ('static', 'c10d', ...): ", default="static" + ) + debug = _ask_field( + "Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + + if distributed_type == DistributedType.NO: + use_cpu = _ask_field( + "Do you want to run your training on CPU only (even if a GPU / Apple Silicon / Ascend NPU device is available)? [yes/NO]:", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + elif distributed_type == DistributedType.MULTI_CPU: + use_cpu = True + else: + use_cpu = False + + mpirun_config = {} + + if use_cpu: + if distributed_type == DistributedType.MULTI_CPU: + use_mpirun = _ask_field( + "Do you want accelerate to launch mpirun? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if use_mpirun: + mpirun_hostfile = _ask_field( + "Please enter the path to the hostfile to use with mpirun [~/hostfile]: ", + str, + default="~/hostfile", + ) + mpirun_config["mpirun_hostfile"] = os.path.expanduser(mpirun_hostfile.strip()) + + dynamo_config = {} + use_dynamo = _ask_field( + "Do you wish to optimize your script with torch dynamo?[yes/NO]:", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if use_dynamo: + prefix = "dynamo_" + dynamo_config[prefix + "backend"] = _ask_options( + "Which dynamo backend would you like to use?", + [x.lower() for x in DYNAMO_BACKENDS], + _convert_dynamo_backend, + default=2, + ) + use_custom_options = _ask_field( + "Do you want to customize the defaults sent to torch.compile? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + + if use_custom_options: + dynamo_config[prefix + "mode"] = _ask_options( + "Which mode do you want to use?", + TORCH_DYNAMO_MODES, + lambda x: TORCH_DYNAMO_MODES[int(x)], + default=0, + ) + dynamo_config[prefix + "use_fullgraph"] = _ask_field( + "Do you want the fullgraph mode or it is ok to break model into several subgraphs? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + dynamo_config[prefix + "use_dynamic"] = _ask_field( + "Do you want to enable dynamic shape tracing? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + dynamo_config[prefix + "use_regional_compilation"] = _ask_field( + "Do you want to enable regional compilation? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + + use_mps = not use_cpu and is_mps_available() + deepspeed_config = {} + if ( + distributed_type + in [ + DistributedType.MULTI_GPU, + DistributedType.MULTI_XPU, + DistributedType.MULTI_HPU, + DistributedType.MULTI_NPU, + DistributedType.MULTI_MLU, + DistributedType.MULTI_SDAA, + DistributedType.MULTI_MUSA, + DistributedType.MULTI_NEURON, + DistributedType.NO, + ] + and not use_mps + ): + use_deepspeed = _ask_field( + "Do you want to use DeepSpeed? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if use_deepspeed: + if distributed_type is DistributedType.MULTI_NEURON: + raise RuntimeError("DeepSpeed is not supported on Neuron devices.") + + distributed_type = DistributedType.DEEPSPEED + assert is_deepspeed_available(), ( + "DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source" + ) + + if distributed_type == DistributedType.DEEPSPEED: + use_deepspeed_config = _ask_field( + "Do you want to specify a json file to a DeepSpeed config? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if use_deepspeed_config: + deepspeed_config["deepspeed_config_file"] = _ask_field( + "Please enter the path to the json DeepSpeed config file: ", + str, + default="none", + ) + else: + deepspeed_config["zero_stage"] = _ask_options( + "What should be your DeepSpeed's ZeRO optimization stage?", + [0, 1, 2, 3], + int, + default=2, + ) + + deepspeed_devices = ["none", "cpu", "nvme"] + if deepspeed_config["zero_stage"] >= 2: + deepspeed_config["offload_optimizer_device"] = _ask_options( + "Where to offload optimizer states?", deepspeed_devices, lambda x: deepspeed_devices[int(x)] + ) + deepspeed_config["offload_param_device"] = _ask_options( + "Where to offload parameters?", deepspeed_devices, lambda x: deepspeed_devices[int(x)] + ) + if deepspeed_config["offload_param_device"] == "nvme": + deepspeed_config["offload_param_nvme_path"] = _ask_field( + "Nvme Path to offload parameters?", + str, + default="/nvme", + ) + if deepspeed_config["offload_optimizer_device"] == "nvme": + deepspeed_config["offload_optimizer_nvme_path"] = _ask_field( + "Nvme Path to offload optimizer states?", + str, + default="/nvme", + ) + deepspeed_config["gradient_accumulation_steps"] = _ask_field( + "How many gradient accumulation steps you're passing in your script? [1]: ", + int, + default=1, + ) + use_gradient_clipping = _ask_field( + "Do you want to use gradient clipping? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if use_gradient_clipping: + deepspeed_config["gradient_clipping"] = _ask_field( + "What is the gradient clipping value? [1.0]: ", + float, + default=1.0, + ) + if deepspeed_config["zero_stage"] == 3: + deepspeed_config["zero3_save_16bit_model"] = _ask_field( + "Do you want to save 16-bit model weights when using ZeRO Stage-3? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + deepspeed_config["zero3_init_flag"] = _ask_field( + "Do you want to enable `deepspeed.zero.Init` when using ZeRO Stage-3 for constructing massive models? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if deepspeed_config["zero3_init_flag"]: + if not is_transformers_available(): + raise Exception( + "When `zero3_init_flag` is set, it requires Transformers to be installed. " + "Please run `pip3 install transformers`." + ) + use_moe = _ask_field( + "Do you want to enable Mixture-of-Experts training (MoE)? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if use_moe: + deepspeed_config["deepspeed_moe_layer_cls_names"] = _ask_field( + "Specify the comma-separated list of transformers MoE layer class names (case-sensitive), e.g : " + " `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock` ... : ", + str, + ) + + if num_machines > 1: + launcher_query = "Which Type of launcher do you want to use?" + deepspeed_config["deepspeed_multinode_launcher"] = _ask_options( + launcher_query, + DEEPSPEED_MULTINODE_LAUNCHERS, + lambda x: DEEPSPEED_MULTINODE_LAUNCHERS[int(x)], + ) + + if deepspeed_config["deepspeed_multinode_launcher"] != DEEPSPEED_MULTINODE_LAUNCHERS[1]: + deepspeed_config["deepspeed_hostfile"] = _ask_field( + "DeepSpeed configures multi-node compute resources with hostfile. " + "Each row is of the format `hostname slots=[num_gpus]`, e.g., `localhost slots=2`; " + "for more information please refer official [documentation]" + "(https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node). " + "Please specify the location of hostfile: ", + str, + ) + + is_exclusion_filter = _ask_field( + "Do you want to specify exclusion filter string? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if is_exclusion_filter: + deepspeed_config["deepspeed_exclusion_filter"] = _ask_field( + "DeepSpeed exclusion filter string: ", + str, + ) + + is_inclusion_filter = _ask_field( + "Do you want to specify inclusion filter string? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if is_inclusion_filter: + deepspeed_config["deepspeed_inclusion_filter"] = _ask_field( + "DeepSpeed inclusion filter string: ", + str, + ) + + fsdp_config = {} + + if distributed_type in [ + DistributedType.MULTI_GPU, + DistributedType.MULTI_NPU, + DistributedType.MULTI_MLU, + DistributedType.MULTI_SDAA, + DistributedType.MULTI_MUSA, + DistributedType.MULTI_XPU, + DistributedType.MULTI_HPU, + DistributedType.MULTI_NEURON, + ]: + use_fsdp = _ask_field( + "Do you want to use FullyShardedDataParallel? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if use_fsdp: + if distributed_type is DistributedType.MULTI_NEURON: + raise NotImplementedError("FSDP is not currently supported on Neuron devices.") + distributed_type = DistributedType.FSDP + + if distributed_type == DistributedType.FSDP: + fsdp_config["fsdp_version"] = _ask_options( + "What should be your FSDP version? [2]: ", + [1, 2], + lambda x: int(x) + 1, + default=1, + ) + fsdp_version = fsdp_config["fsdp_version"] # extract to a variable to simplify usage later + + if fsdp_version == 1: + sharding_strategy_query = "What should be your sharding strategy?" + fsdp_config["fsdp_reshard_after_forward"] = _ask_options( + sharding_strategy_query, + FSDP_SHARDING_STRATEGY, + lambda x: FSDP_SHARDING_STRATEGY[int(x)], + ) + else: + fsdp_config["fsdp_reshard_after_forward"] = _ask_field( + "Do you want to enable resharding after forward? [YES/no]: ", + _convert_yes_no_to_bool, + default=True, + error_message="Please enter yes or no.", + ) + + fsdp_config["fsdp_offload_params"] = _ask_field( + "Do you want to offload parameters and gradients to CPU? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + + fsdp_wrap_query = "What should be your auto wrap policy?" + fsdp_config["fsdp_auto_wrap_policy"] = _ask_options( + fsdp_wrap_query, + FSDP_AUTO_WRAP_POLICY, + lambda x: FSDP_AUTO_WRAP_POLICY[int(x)], + ) + if fsdp_config["fsdp_auto_wrap_policy"] == FSDP_AUTO_WRAP_POLICY[0]: + use_no_split_modules = _ask_field( + "Do you want to use the model's `_no_split_modules` to wrap. Only applicable for 🤗 Transformers [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if not use_no_split_modules: + fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = _ask_field( + "Specify the comma-separated list of transformer layer class names (case-sensitive) to wrap ,e.g, :" + "`BertLayer`, `GPTJBlock`, `T5Block`, `BertLayer,BertEmbeddings,BertSelfOutput` ...? : ", + str, + ) + elif fsdp_config["fsdp_auto_wrap_policy"] == FSDP_AUTO_WRAP_POLICY[1]: + fsdp_config["fsdp_min_num_params"] = _ask_field( + "What should be your FSDP's minimum number of parameters for Default Auto Wrapping Policy? [1e8]: ", + int, + default=100000000, + ) + # Removed in FSDP2, ask for user input for FSDP1 + if fsdp_version == 1: + fsdp_backward_prefetch_query = "What should be your FSDP's backward prefetch policy?" + fsdp_config["fsdp_backward_prefetch"] = _ask_options( + fsdp_backward_prefetch_query, + FSDP_BACKWARD_PREFETCH, + lambda x: FSDP_BACKWARD_PREFETCH[int(x)], + ) + + fsdp_state_dict_type_query = "What should be your FSDP's state dict type?" + fsdp_config["fsdp_state_dict_type"] = _ask_options( + fsdp_state_dict_type_query, + FSDP_STATE_DICT_TYPE if fsdp_version == 1 else FSDP2_STATE_DICT_TYPE, + lambda x: FSDP_STATE_DICT_TYPE[int(x)] if fsdp_version == 1 else FSDP2_STATE_DICT_TYPE[int(x)], + default=0, + ) + # Not implemented in FSDP2, ask for user input for FSDP1 + if fsdp_version == 1: + fsdp_config["fsdp_forward_prefetch"] = _ask_field( + "Do you want to enable FSDP's forward prefetch policy? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + # Obsolete in FSDP2, ask for user input for FSDP1 + if fsdp_version == 1: + fsdp_config["fsdp_use_orig_params"] = _ask_field( + "Do you want to enable FSDP's `use_orig_params` feature? [YES/no]: ", + _convert_yes_no_to_bool, + default=True, + error_message="Please enter yes or no.", + ) + fsdp_config["fsdp_cpu_ram_efficient_loading"] = _ask_field( + "Do you want to enable CPU RAM efficient model loading? Only applicable for 🤗 Transformers models. [YES/no]: ", + _convert_yes_no_to_bool, + default=True, + error_message="Please enter yes or no.", + ) + # Obsolete in FSDP2, ask for user input for FSDP1 + if fsdp_version == 1: + if fsdp_config["fsdp_cpu_ram_efficient_loading"]: + fsdp_config["fsdp_sync_module_states"] = True + else: + fsdp_config["fsdp_sync_module_states"] = _ask_field( + "Do you want each individually wrapped FSDP unit to broadcast module parameters from rank 0 at the start? [YES/no]: ", + _convert_yes_no_to_bool, + default=True, + error_message="Please enter yes or no.", + ) + fsdp_config["fsdp_activation_checkpointing"] = _ask_field( + "Do you want to enable FSDP activation checkpointing? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + + parallelism_config = {} + + if fsdp_config.get("fsdp_version", 1) == 2: + use_parallelism_config = _ask_field( + "Do you want to use the parallelism config? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + + if use_parallelism_config: + prefix = "parallelism_config_" + parallelism_config[prefix + "dp_replicate_size"] = _ask_field( + "What is the data parallelism replicate size? [1]: ", + int, + default=1, + error_message="Please enter an integer.", + ) + + parallelism_config[prefix + "dp_shard_size"] = _ask_field( + "What is the FSDP shard size? [1]: ", + int, + default=1, + error_message="Please enter an integer.", + ) + + parallelism_config[prefix + "tp_size"] = _ask_field( + "What is the tensor parallelism size? [1]: ", + int, + default=1, + error_message="Please enter an integer.", + ) + + parallelism_config[prefix + "cp_size"] = _ask_field( + "What is the context parallelism size? [1]: ", + int, + default=1, + error_message="Please enter an integer.", + ) + if parallelism_config[prefix + "cp_size"] > 1: + parallelism_config[prefix + "cp_comm_strategy"] = _ask_options( + "What is the compute parallelism communication strategy?", + ["allgather", "alltoall"], + lambda x: ["allgather", "alltoall"][int(x)], + default=0, + ) + + megatron_lm_config = {} + if distributed_type in [DistributedType.MULTI_GPU]: + use_megatron_lm = _ask_field( + "Do you want to use Megatron-LM ? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if use_megatron_lm: + distributed_type = DistributedType.MEGATRON_LM + if distributed_type == DistributedType.MEGATRON_LM: + prefix = "megatron_lm_" + megatron_lm_config[prefix + "tp_degree"] = _ask_field( + "What is the Tensor Parallelism degree/size? [1]:", + int, + default=1, + error_message="Please enter an integer.", + ) + if megatron_lm_config[prefix + "tp_degree"] > 1: + megatron_lm_config[prefix + "sequence_parallelism"] = _ask_field( + "Do you want to enable Sequence Parallelism? [YES/no]: ", + _convert_yes_no_to_bool, + default=True, + error_message="Please enter yes or no.", + ) + + megatron_lm_config[prefix + "pp_degree"] = _ask_field( + "What is the Pipeline Parallelism degree/size? [1]:", + int, + default=1, + error_message="Please enter an integer.", + ) + if megatron_lm_config[prefix + "pp_degree"] > 1: + megatron_lm_config[prefix + "num_micro_batches"] = _ask_field( + "What is the number of micro-batches? [1]:", + int, + default=1, + error_message="Please enter an integer.", + ) + + megatron_lm_config[prefix + "recompute_activations"] = _ask_field( + "Do you want to enable selective activation recomputation? [YES/no]: ", + _convert_yes_no_to_bool, + default=True, + error_message="Please enter yes or no.", + ) + + megatron_lm_config[prefix + "use_distributed_optimizer"] = _ask_field( + "Do you want to use distributed optimizer " + "which shards optimizer state and gradients across data parallel ranks? [YES/no]: ", + _convert_yes_no_to_bool, + default=True, + error_message="Please enter yes or no.", + ) + + megatron_lm_config[prefix + "gradient_clipping"] = _ask_field( + "What is the gradient clipping value based on global L2 Norm (0 to disable)? [1.0]: ", + float, + default=1.0, + ) + # TPU specific defaults + tpu_commands = None + tpu_command_file = None + tpu_downcast_bf16 = "no" + tpu_env = [] + tpu_name = None + tpu_vm = None + tpu_zone = None + tpu_use_sudo = False + tpu_use_cluster = False + + if distributed_type in [ + DistributedType.MULTI_CPU, + DistributedType.MULTI_XPU, + DistributedType.MULTI_HPU, + DistributedType.MULTI_GPU, + DistributedType.MULTI_MLU, + DistributedType.MULTI_SDAA, + DistributedType.MULTI_MUSA, + DistributedType.MULTI_NPU, + DistributedType.MULTI_NEURON, + DistributedType.XLA, + ]: + machine_type = str(distributed_type).split(".")[1].replace("MULTI_", "") + if machine_type in ["TPU", "NEURON"]: + machine_type += " cores" + elif machine_type == "CPU": + machine_type = "processes" + else: + machine_type += "(s)" + num_processes = _ask_field( + f"How many {machine_type} should be used for distributed training? [1]:", + int, + default=1, + error_message="Please enter an integer.", + ) + elif distributed_type in [DistributedType.FSDP, DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]: + num_processes = _ask_field( + "How many GPU(s) should be used for distributed training? [1]:", + int, + default=1, + error_message="Please enter an integer.", + ) + else: + num_processes = 1 + + if (distributed_type == DistributedType.MULTI_GPU) and (num_machines == 1) and (num_processes == 1): + raise ValueError( + f"Specified distributed type {distributed_type} but only using 1 GPU on a single machine. Please select `No distributed training` for the type of machine you are using." + ) + + if ( + distributed_type + in [ + DistributedType.MULTI_GPU, + DistributedType.MULTI_MLU, + DistributedType.MULTI_SDAA, + DistributedType.MULTI_MUSA, + DistributedType.MULTI_NPU, + DistributedType.MULTI_XPU, + DistributedType.MULTI_HPU, + DistributedType.MULTI_NEURON, + DistributedType.NO, + ] + and not use_cpu + and not use_mps + ): + if is_npu_available(): + machine_type = "NPU(s)" + elif is_mlu_available(): + machine_type = "MLU(s)" + elif is_sdaa_available(): + machine_type = "SDAA(s)" + elif is_musa_available(): + machine_type = "MUSA(s)" + elif is_xpu_available(): + machine_type = "XPU(s)" + elif is_hpu_available(): + machine_type = "HPU(s)" + elif is_neuron_available(): + machine_type = "Neuron cores" + else: + machine_type = "GPU(s)" + gpu_ids = _ask_field( + f"What {machine_type} (by id) should be used for training on this machine as a comma-separated list? [all]:", + default="all", + ) + + # CPU affinity is only supported on NVIDIA hardware for now + enable_cpu_affinity = False + if distributed_type in (DistributedType.NO, DistributedType.MULTI_GPU) and not use_cpu and not use_mps: + enable_cpu_affinity = _ask_field( + "Would you like to enable numa efficiency? (Currently only supported on NVIDIA hardware). [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + + fp8_config = None + if distributed_type == DistributedType.XLA: + mixed_precision = "no" + main_training_function = _ask_field( + "What is the name of the function in your script that should be launched in all parallel scripts? [main]: ", + default="main", + ) + tpu_use_cluster = _ask_field( + "Are you using a TPU cluster? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if tpu_use_cluster: + tpu_name = _ask_field( + "What is the name of your TPU cluster? ", + default=None, + error_message="Please enter the name of your TPU cluster.", + ) + tpu_zone = _ask_field( + "What is the zone of your TPU cluster? ", + default=None, + error_message="Please enter the zone of your TPU cluster.", + ) + tpu_use_sudo = _ask_field( + "To run a python script in a TPU pod, should `sudo` be used? [yes/NO]: ", + default=False, + error_message="Please enter yes or no.", + ) + run_commands = _ask_field( + "Do you have code you wish to run on startup in each pod? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if run_commands: + use_command_file = _ask_field( + "Is this code located in a bash script? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if use_command_file: + tpu_command_file = _ask_field( + "What is the path to your bash script? ", + default=None, + error_message="Please enter the path to your bash script.", + ) + tpu_command_file = os.path.abspath(tpu_command_file) + else: + print("Please enter each command separately you wish to run on startup in each pod.") + tpu_commands = [] + another_command = True + while another_command: + tpu_commands.append( + _ask_field( + "Please enter a single command to be ran ", + default=None, + error_message="Please enter the commands you wish to run on startup in each pod as a single string.", + ) + ) + another_command = _ask_field( + "Do you wish to add another command? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + tpu_vm = _ask_field( + "If not using an instance group, what are the names of the Compute VM instances to be used, separated by a comma: ", + default="", + ).split(",") + tpu_env = _ask_field( + "What environment variables do you wish to set in each pod, separated by a comma: ", + default="", + ).split(",") + + else: + main_training_function = "main" + if distributed_type == DistributedType.DEEPSPEED and use_deepspeed_config: + mixed_precision = None + else: + mixed_precision = _ask_options( + "Do you wish to use mixed precision?", + ["no", "fp16", "bf16", "fp8"], + _convert_mixed_precision, + ) + if mixed_precision == "fp8": + if not is_fp8_available(): + raise ValueError( + "FP8 (either torchao, Transformer Engine or MSAMP) is not installed on this machine." + ) + fp8_config = {} + fp8_config["backend"] = _ask_options( + "Which FP8 backend do you want to use?", + ["ao", "te", "msamp"], + _convert_fp8_backend, + ) + if fp8_config["backend"] == "TE": + if not is_transformer_engine_available(): + raise ValueError("TransformersEngine was selected, but it is not installed on this machine.") + fp8_config["use_autocast_during_eval"] = _ask_field( + "Do you want to use FP8 autocast during eval mode? Generally better metrics are found when this is disabled [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + ) + fp8_config["margin"] = _ask_field( + "What margin should be used for gradient scaling? [0]: ", + int, + default=0, + ) + fp8_config["interval"] = _ask_field( + "What interval should be used for for how often the scaling factor is recomputed? [1]: ", + int, + default=1, + ) + fp8_config["fp8_format"] = _ask_options( + "Which weight format should be used?", + ["HYBRID", "E4M3", "E5M2"], + lambda i: ["HYBRID", "E4M3", "E5M2"][i], + default=0, + ) + fp8_config["amax_history_length"] = _ask_field( + "What length of history should be used for the amax scaling factor computation? [1024]: ", + int, + default=1024, + ) + fp8_config["amax_compute_algorithm"] = _ask_options( + "Which algorithm should be used for the amax scaling factor computation?", + ["max", "most_recent"], + lambda x: "max" if x == 0 else "most_recent", + default=0, + ) + fp8_config["override_linear_precision"] = _ask_field( + "Do you want to to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + ) + if fp8_config["override_linear_precision"]: + fprop = _ask_field( + "Should `fprop` be executed in higher precision? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + ) + dgrad = _ask_field( + "Should `dgrad` be executed in higher precision? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + ) + wgrad = _ask_field( + "Should `wgrad` be executed in higher precision? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + ) + fp8_config["override_linear_precision"] = (fprop, dgrad, wgrad) + else: + fp8_config["override_linear_precision"] = (False, False, False) + + elif fp8_config["backend"] == "MSAMP": + if not is_msamp_available(): + raise ValueError("MSAMP was selected, but it is not installed on this machine.") + fp8_config["optimization_level"] = _ask_options( + "Which optimization level should be used?", + ["O1", "O2"], + lambda x: "O1" if x == 0 else "O2", + default=1, + ) + + elif fp8_config["backend"] == "AO": + if not is_torchao_available(): + raise ValueError("torchao was selected, but it is not installed on this machine.") + fp8_config["enable_fsdp_float8_all_gather"] = _ask_field( + "Do you want to enable FSDP2 float8 all gather? This is recommended for better performance if using FSDP2. [YES/no]: ", + _convert_yes_no_to_bool, + default=True, + ) + fp8_config["pad_inner_dim"] = _ask_field( + "Do you want to pad the inner dimension of weight matrices before float8 matmuls? This is required for _scaled_mm which has strict alignment requirements. Note: padding may cause memory spikes. [YES/no]: ", + _convert_yes_no_to_bool, + default=True, + ) + + if use_dynamo and mixed_precision == "no" and not use_cpu: + print( + "Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts." + ) + + if distributed_type == DistributedType.XLA and mixed_precision == "bf16": + tpu_downcast_bf16 = _ask_field( + "Should `torch.float` be cast as `bfloat16` and `torch.double` remain `float32` on TPUs?", default="no" + ) + + return ClusterConfig( + compute_environment=ComputeEnvironment.LOCAL_MACHINE, + distributed_type=distributed_type, + num_processes=num_processes, + gpu_ids=gpu_ids, + mixed_precision=mixed_precision, + downcast_bf16=tpu_downcast_bf16, + machine_rank=machine_rank, + num_machines=num_machines, + main_process_ip=main_process_ip, + main_process_port=main_process_port, + main_training_function=main_training_function, + fp8_config=fp8_config, + deepspeed_config=deepspeed_config, + fsdp_config=fsdp_config, + parallelism_config=parallelism_config, + megatron_lm_config=megatron_lm_config, + mpirun_config=mpirun_config, + use_cpu=use_cpu, + rdzv_backend=rdzv_backend, + same_network=same_network, + commands=tpu_commands, + command_file=tpu_command_file, + tpu_env=tpu_env, + tpu_name=tpu_name, + tpu_vm=tpu_vm, + tpu_zone=tpu_zone, + tpu_use_sudo=tpu_use_sudo, + tpu_use_cluster=tpu_use_cluster, + dynamo_config=dynamo_config, + debug=debug, + enable_cpu_affinity=enable_cpu_affinity, + ) diff --git a/accelerate/commands/config/config.py b/accelerate/commands/config/config.py new file mode 100644 index 0000000000000000000000000000000000000000..72414f2abe62d76bd5133f4b0ed99bf34133f6f6 --- /dev/null +++ b/accelerate/commands/config/config.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +from accelerate.utils import ComputeEnvironment + +from .cluster import get_cluster_input +from .config_args import cache_dir, default_config_file, default_yaml_config_file, load_config_from_file # noqa: F401 +from .config_utils import _ask_field, _ask_options, _convert_compute_environment # noqa: F401 +from .sagemaker import get_sagemaker_input + + +description = "Launches a series of prompts to create and save a `default_config.yaml` configuration file for your training system. Should always be ran first on your machine" + + +def get_user_input(): + compute_environment = _ask_options( + "In which compute environment are you running?", + ["This machine", "AWS (Amazon SageMaker)"], + _convert_compute_environment, + ) + if compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER: + config = get_sagemaker_input() + else: + config = get_cluster_input() + return config + + +def config_command_parser(subparsers=None): + if subparsers is not None: + parser = subparsers.add_parser("config", description=description) + else: + parser = argparse.ArgumentParser("Accelerate config command", description=description) + + parser.add_argument( + "--config_file", + default=None, + help=( + "The path to use to store the config file. Will default to a file named default_config.yaml in the cache " + "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have " + "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed " + "with 'huggingface'." + ), + ) + + if subparsers is not None: + parser.set_defaults(func=config_command) + return parser + + +def config_command(args): + config = get_user_input() + if args.config_file is not None: + config_file = args.config_file + else: + if not os.path.isdir(cache_dir): + os.makedirs(cache_dir) + config_file = default_yaml_config_file + + if config_file.endswith(".json"): + config.to_json_file(config_file) + else: + config.to_yaml_file(config_file) + print(f"accelerate configuration saved at {config_file}") + + +def main(): + parser = config_command_parser() + args = parser.parse_args() + config_command(args) + + +if __name__ == "__main__": + main() diff --git a/accelerate/commands/config/config_args.py b/accelerate/commands/config/config_args.py new file mode 100644 index 0000000000000000000000000000000000000000..9a50fade655da8e1ef766817e013185728c6b7bc --- /dev/null +++ b/accelerate/commands/config/config_args.py @@ -0,0 +1,252 @@ +#!/usr/bin/env python + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Union + +import yaml + +from ...utils import ComputeEnvironment, DistributedType, SageMakerDistributedType +from ...utils.constants import SAGEMAKER_PYTHON_VERSION, SAGEMAKER_PYTORCH_VERSION, SAGEMAKER_TRANSFORMERS_VERSION + + +hf_cache_home = os.path.expanduser( + os.environ.get("HF_HOME", os.path.join(os.environ.get("XDG_CACHE_HOME", "~/.cache"), "huggingface")) +) +cache_dir = os.path.join(hf_cache_home, "accelerate") +default_json_config_file = os.path.join(cache_dir, "default_config.yaml") +default_yaml_config_file = os.path.join(cache_dir, "default_config.yaml") + +# For backward compatibility: the default config is the json one if it's the only existing file. +if os.path.isfile(default_yaml_config_file) or not os.path.isfile(default_json_config_file): + default_config_file = default_yaml_config_file +else: + default_config_file = default_json_config_file + + +def load_config_from_file(config_file): + if config_file is not None: + if not os.path.isfile(config_file): + raise FileNotFoundError( + f"The passed configuration file `{config_file}` does not exist. " + "Please pass an existing file to `accelerate launch`, or use the default one " + "created through `accelerate config` and run `accelerate launch` " + "without the `--config_file` argument." + ) + else: + config_file = default_config_file + with open(config_file, encoding="utf-8") as f: + if config_file.endswith(".json"): + if ( + json.load(f).get("compute_environment", ComputeEnvironment.LOCAL_MACHINE) + == ComputeEnvironment.LOCAL_MACHINE + ): + config_class = ClusterConfig + else: + config_class = SageMakerConfig + return config_class.from_json_file(json_file=config_file) + else: + if ( + yaml.safe_load(f).get("compute_environment", ComputeEnvironment.LOCAL_MACHINE) + == ComputeEnvironment.LOCAL_MACHINE + ): + config_class = ClusterConfig + else: + config_class = SageMakerConfig + return config_class.from_yaml_file(yaml_file=config_file) + + +@dataclass +class BaseConfig: + compute_environment: ComputeEnvironment + distributed_type: Union[DistributedType, SageMakerDistributedType] + mixed_precision: str + use_cpu: bool + debug: bool + + def to_dict(self): + result = self.__dict__ + # For serialization, it's best to convert Enums to strings (or their underlying value type). + + def _convert_enums(value): + if isinstance(value, Enum): + return value.value + if isinstance(value, dict): + if not bool(value): + return None + for key1, value1 in value.items(): + value[key1] = _convert_enums(value1) + return value + + for key, value in result.items(): + result[key] = _convert_enums(value) + result = {k: v for k, v in result.items() if v is not None} + return result + + @staticmethod + def process_config(config_dict): + """ + Processes `config_dict` and sets default values for any missing keys + """ + if "compute_environment" not in config_dict: + config_dict["compute_environment"] = ComputeEnvironment.LOCAL_MACHINE + if "distributed_type" not in config_dict: + raise ValueError("A `distributed_type` must be specified in the config file.") + if "num_processes" not in config_dict and config_dict["distributed_type"] == DistributedType.NO: + config_dict["num_processes"] = 1 + if "mixed_precision" not in config_dict: + config_dict["mixed_precision"] = "fp16" if ("fp16" in config_dict and config_dict["fp16"]) else None + if "fp16" in config_dict: # Convert the config to the new format. + del config_dict["fp16"] + if "dynamo_backend" in config_dict: # Convert the config to the new format. + dynamo_backend = config_dict.pop("dynamo_backend") + config_dict["dynamo_config"] = {} if dynamo_backend == "NO" else {"dynamo_backend": dynamo_backend} + if "use_cpu" not in config_dict: + config_dict["use_cpu"] = False + if "debug" not in config_dict: + config_dict["debug"] = False + if "enable_cpu_affinity" not in config_dict: + config_dict["enable_cpu_affinity"] = False + return config_dict + + @classmethod + def from_json_file(cls, json_file=None): + json_file = default_json_config_file if json_file is None else json_file + with open(json_file, encoding="utf-8") as f: + config_dict = json.load(f) + config_dict = cls.process_config(config_dict) + extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys())) + if len(extra_keys) > 0: + raise ValueError( + f"The config file at {json_file} had unknown keys ({extra_keys}), please try upgrading your `accelerate`" + " version or fix (and potentially remove) these keys from your config file." + ) + + return cls(**config_dict) + + def to_json_file(self, json_file): + with open(json_file, "w", encoding="utf-8") as f: + content = json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + f.write(content) + + @classmethod + def from_yaml_file(cls, yaml_file=None): + yaml_file = default_yaml_config_file if yaml_file is None else yaml_file + with open(yaml_file, encoding="utf-8") as f: + config_dict = yaml.safe_load(f) + config_dict = cls.process_config(config_dict) + extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys())) + if len(extra_keys) > 0: + raise ValueError( + f"The config file at {yaml_file} had unknown keys ({extra_keys}), please try upgrading your `accelerate`" + " version or fix (and potentially remove) these keys from your config file." + ) + return cls(**config_dict) + + def to_yaml_file(self, yaml_file): + with open(yaml_file, "w", encoding="utf-8") as f: + yaml.safe_dump(self.to_dict(), f) + + def __post_init__(self): + if isinstance(self.compute_environment, str): + self.compute_environment = ComputeEnvironment(self.compute_environment) + if isinstance(self.distributed_type, str): + if self.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER: + self.distributed_type = SageMakerDistributedType(self.distributed_type) + else: + self.distributed_type = DistributedType(self.distributed_type) + if getattr(self, "dynamo_config", None) is None: + self.dynamo_config = {} + + +@dataclass +class ClusterConfig(BaseConfig): + num_processes: int = -1 # For instance if we use SLURM and the user manually passes it in + machine_rank: int = 0 + num_machines: int = 1 + gpu_ids: Optional[str] = None + main_process_ip: Optional[str] = None + main_process_port: Optional[int] = None + rdzv_backend: Optional[str] = "static" + same_network: Optional[bool] = False + main_training_function: str = "main" + enable_cpu_affinity: bool = False + + # args for FP8 training + fp8_config: Optional[dict] = None + # args for deepspeed_plugin + deepspeed_config: Optional[dict] = None + # args for fsdp + fsdp_config: Optional[dict] = None + # args for parallelism config + parallelism_config: Optional[dict] = None + # args for megatron_lm + megatron_lm_config: Optional[dict] = None + # args for mpirun + mpirun_config: Optional[dict] = None + # args for TPU + downcast_bf16: bool = False + + # args for TPU pods + tpu_name: Optional[str] = None + tpu_zone: Optional[str] = None + tpu_use_cluster: bool = False + tpu_use_sudo: bool = False + command_file: Optional[str] = None + commands: list[str] = None + tpu_vm: list[str] = None + tpu_env: list[str] = None + + # args for dynamo + dynamo_config: Optional[dict] = None + + def __post_init__(self): + if self.deepspeed_config is None: + self.deepspeed_config = {} + if self.fsdp_config is None: + self.fsdp_config = {} + if self.megatron_lm_config is None: + self.megatron_lm_config = {} + if self.mpirun_config is None: + self.mpirun_config = {} + if self.fp8_config is None: + self.fp8_config = {} + if self.parallelism_config is None: + self.parallelism_config = {} + return super().__post_init__() + + +@dataclass +class SageMakerConfig(BaseConfig): + ec2_instance_type: str + iam_role_name: str + image_uri: Optional[str] = None + profile: Optional[str] = None + region: str = "us-east-1" + num_machines: int = 1 + gpu_ids: str = "all" + base_job_name: str = f"accelerate-sagemaker-{num_machines}" + pytorch_version: str = SAGEMAKER_PYTORCH_VERSION + transformers_version: str = SAGEMAKER_TRANSFORMERS_VERSION + py_version: str = SAGEMAKER_PYTHON_VERSION + sagemaker_inputs_file: Optional[str] = None + sagemaker_metrics_file: Optional[str] = None + additional_args: Optional[dict] = None + dynamo_config: Optional[dict] = None + enable_cpu_affinity: bool = False diff --git a/accelerate/commands/config/config_utils.py b/accelerate/commands/config/config_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d29bf4726a8205e1d01fb5f279ddbc7d7160b944 --- /dev/null +++ b/accelerate/commands/config/config_utils.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +from ...utils.dataclasses import ( + ComputeEnvironment, + DistributedType, + DynamoBackend, + FP8BackendType, + PrecisionType, + SageMakerDistributedType, +) +from ..menu import BulletMenu + + +DYNAMO_BACKENDS = [ + "EAGER", + "AOT_EAGER", + "INDUCTOR", + "AOT_TS_NVFUSER", + "NVPRIMS_NVFUSER", + "CUDAGRAPHS", + "OFI", + "FX2TRT", + "ONNXRT", + "TENSORRT", + "AOT_TORCHXLA_TRACE_ONCE", + "TORHCHXLA_TRACE_ONCE", + "TVM", +] + + +def _ask_field(input_text, convert_value=None, default=None, error_message=None): + ask_again = True + while ask_again: + result = input(input_text) + try: + if default is not None and len(result) == 0: + return default + return convert_value(result) if convert_value is not None else result + except Exception: + if error_message is not None: + print(error_message) + + +def _ask_options(input_text, options=[], convert_value=None, default=0): + menu = BulletMenu(input_text, options) + result = menu.run(default_choice=default) + return convert_value(result) if convert_value is not None else result + + +def _convert_compute_environment(value): + value = int(value) + return ComputeEnvironment(["LOCAL_MACHINE", "AMAZON_SAGEMAKER"][value]) + + +def _convert_distributed_mode(value): + value = int(value) + return DistributedType( + [ + "NO", + "MULTI_CPU", + "MULTI_XPU", + "MULTI_HPU", + "MULTI_GPU", + "MULTI_NPU", + "MULTI_MLU", + "MULTI_SDAA", + "MULTI_MUSA", + "MULTI_NEURON", + "XLA", + ][value] + ) + + +def _convert_dynamo_backend(value): + value = int(value) + return DynamoBackend(DYNAMO_BACKENDS[value]).value + + +def _convert_mixed_precision(value): + value = int(value) + return PrecisionType(["no", "fp16", "bf16", "fp8"][value]) + + +def _convert_sagemaker_distributed_mode(value): + value = int(value) + return SageMakerDistributedType(["NO", "DATA_PARALLEL", "MODEL_PARALLEL"][value]) + + +def _convert_fp8_backend(value): + value = int(value) + return FP8BackendType(["AO", "TE", "MSAMP"][value]) + + +def _convert_yes_no_to_bool(value): + return {"yes": True, "no": False}[value.lower()] + + +class SubcommandHelpFormatter(argparse.RawDescriptionHelpFormatter): + """ + A custom formatter that will remove the usage line from the help message for subcommands. + """ + + def _format_usage(self, usage, actions, groups, prefix): + usage = super()._format_usage(usage, actions, groups, prefix) + usage = usage.replace(" [] ", "") + return usage diff --git a/accelerate/commands/config/default.py b/accelerate/commands/config/default.py new file mode 100644 index 0000000000000000000000000000000000000000..c505d8be937a1b59151e58f9298ef920e9f06f25 --- /dev/null +++ b/accelerate/commands/config/default.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import torch + +from ...utils import ( + is_hpu_available, + is_mlu_available, + is_musa_available, + is_neuron_available, + is_npu_available, + is_sdaa_available, + is_xpu_available, +) +from .config_args import ClusterConfig, default_json_config_file +from .config_utils import SubcommandHelpFormatter + + +description = "Create a default config file for Accelerate with only a few flags set." + + +def write_basic_config(mixed_precision="no", save_location: str = default_json_config_file): + """ + Creates and saves a basic cluster config to be used on a local machine with potentially multiple GPUs. Will also + set CPU if it is a CPU-only machine. + + Args: + mixed_precision (`str`, *optional*, defaults to "no"): + Mixed Precision to use. Should be one of "no", "fp16", or "bf16" + save_location (`str`, *optional*, defaults to `default_json_config_file`): + Optional custom save location. Should be passed to `--config_file` when using `accelerate launch`. Default + location is inside the huggingface cache folder (`~/.cache/huggingface`) but can be overridden by setting + the `HF_HOME` environmental variable, followed by `accelerate/default_config.yaml`. + """ + path = Path(save_location) + path.parent.mkdir(parents=True, exist_ok=True) + if path.exists(): + print( + f"Configuration already exists at {save_location}, will not override. Run `accelerate config` manually or pass a different `save_location`." + ) + return False + mixed_precision = mixed_precision.lower() + if mixed_precision not in ["no", "fp16", "bf16", "fp8"]: + raise ValueError( + f"`mixed_precision` should be one of 'no', 'fp16', 'bf16', or 'fp8'. Received {mixed_precision}" + ) + config = { + "compute_environment": "LOCAL_MACHINE", + "mixed_precision": mixed_precision, + } + if is_mlu_available(): + num_mlus = torch.mlu.device_count() + config["num_processes"] = num_mlus + config["use_cpu"] = False + if num_mlus > 1: + config["distributed_type"] = "MULTI_MLU" + else: + config["distributed_type"] = "NO" + if is_sdaa_available(): + num_sdaas = torch.sdaa.device_count() + config["num_processes"] = num_sdaas + config["use_cpu"] = False + if num_sdaas > 1: + config["distributed_type"] = "MULTI_SDAA" + else: + config["distributed_type"] = "NO" + elif is_musa_available(): + num_musas = torch.musa.device_count() + config["num_processes"] = num_musas + config["use_cpu"] = False + if num_musas > 1: + config["distributed_type"] = "MULTI_MUSA" + else: + config["distributed_type"] = "NO" + elif is_hpu_available(): + num_hpus = torch.hpu.device_count() + config["num_processes"] = num_hpus + config["use_cpu"] = False + if num_hpus > 1: + config["distributed_type"] = "MULTI_HPU" + else: + config["distributed_type"] = "NO" + elif torch.cuda.is_available(): + num_gpus = torch.cuda.device_count() + config["num_processes"] = num_gpus + config["use_cpu"] = False + if num_gpus > 1: + config["distributed_type"] = "MULTI_GPU" + else: + config["distributed_type"] = "NO" + elif is_xpu_available(): + num_xpus = torch.xpu.device_count() + config["num_processes"] = num_xpus + config["use_cpu"] = False + if num_xpus > 1: + config["distributed_type"] = "MULTI_XPU" + else: + config["distributed_type"] = "NO" + elif is_npu_available(): + num_npus = torch.npu.device_count() + config["num_processes"] = num_npus + config["use_cpu"] = False + if num_npus > 1: + config["distributed_type"] = "MULTI_NPU" + else: + config["distributed_type"] = "NO" + elif is_neuron_available(): + num_neuron_cores = torch.neuron.device_count() + config["num_processes"] = num_neuron_cores + config["use_cpu"] = False + if num_neuron_cores > 1: + config["distributed_type"] = "MULTI_NEURON" + else: + config["distributed_type"] = "NO" + else: + num_xpus = 0 + config["use_cpu"] = True + config["num_processes"] = 1 + config["distributed_type"] = "NO" + config["debug"] = False + config["enable_cpu_affinity"] = False + config = ClusterConfig(**config) + config.to_json_file(path) + return path + + +def default_command_parser(parser, parents): + parser = parser.add_parser("default", parents=parents, help=description, formatter_class=SubcommandHelpFormatter) + parser.add_argument( + "--config_file", + default=default_json_config_file, + help=( + "The path to use to store the config file. Will default to a file named default_config.yaml in the cache " + "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have " + "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed " + "with 'huggingface'." + ), + dest="save_location", + ) + + parser.add_argument( + "--mixed_precision", + choices=["no", "fp16", "bf16"], + type=str, + help="Whether or not to use mixed precision training. " + "Choose between FP16 and BF16 (bfloat16) training. " + "BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later.", + default="no", + ) + parser.set_defaults(func=default_config_command) + return parser + + +def default_config_command(args): + config_file = write_basic_config(args.mixed_precision, args.save_location) + if config_file: + print(f"accelerate configuration saved at {config_file}") diff --git a/accelerate/commands/config/sagemaker.py b/accelerate/commands/config/sagemaker.py new file mode 100644 index 0000000000000000000000000000000000000000..5092ef31fc4715f901be6c1e7bfe80c0b140d767 --- /dev/null +++ b/accelerate/commands/config/sagemaker.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os + +from ...utils.constants import SAGEMAKER_PARALLEL_EC2_INSTANCES, TORCH_DYNAMO_MODES +from ...utils.dataclasses import ComputeEnvironment, SageMakerDistributedType +from ...utils.imports import is_boto3_available +from .config_args import SageMakerConfig +from .config_utils import ( + DYNAMO_BACKENDS, + _ask_field, + _ask_options, + _convert_dynamo_backend, + _convert_mixed_precision, + _convert_sagemaker_distributed_mode, + _convert_yes_no_to_bool, +) + + +if is_boto3_available(): + import boto3 # noqa: F401 + + +def _create_iam_role_for_sagemaker(role_name): + iam_client = boto3.client("iam") + + sagemaker_trust_policy = { + "Version": "2012-10-17", + "Statement": [ + {"Effect": "Allow", "Principal": {"Service": "sagemaker.amazonaws.com"}, "Action": "sts:AssumeRole"} + ], + } + try: + # create the role, associated with the chosen trust policy + iam_client.create_role( + RoleName=role_name, AssumeRolePolicyDocument=json.dumps(sagemaker_trust_policy, indent=2) + ) + policy_document = { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "sagemaker:*", + "ecr:GetDownloadUrlForLayer", + "ecr:BatchGetImage", + "ecr:BatchCheckLayerAvailability", + "ecr:GetAuthorizationToken", + "cloudwatch:PutMetricData", + "cloudwatch:GetMetricData", + "cloudwatch:GetMetricStatistics", + "cloudwatch:ListMetrics", + "logs:CreateLogGroup", + "logs:CreateLogStream", + "logs:DescribeLogStreams", + "logs:PutLogEvents", + "logs:GetLogEvents", + "s3:CreateBucket", + "s3:ListBucket", + "s3:GetBucketLocation", + "s3:GetObject", + "s3:PutObject", + ], + "Resource": "*", + } + ], + } + # attach policy to role + iam_client.put_role_policy( + RoleName=role_name, + PolicyName=f"{role_name}_policy_permission", + PolicyDocument=json.dumps(policy_document, indent=2), + ) + except iam_client.exceptions.EntityAlreadyExistsException: + print(f"role {role_name} already exists. Using existing one") + + +def _get_iam_role_arn(role_name): + iam_client = boto3.client("iam") + return iam_client.get_role(RoleName=role_name)["Role"]["Arn"] + + +def get_sagemaker_input(): + credentials_configuration = _ask_options( + "How do you want to authorize?", + ["AWS Profile", "Credentials (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) "], + int, + ) + aws_profile = None + if credentials_configuration == 0: + aws_profile = _ask_field("Enter your AWS Profile name: [default] ", default="default") + os.environ["AWS_PROFILE"] = aws_profile + else: + print( + "Note you will need to provide AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY when you launch you training script with," + "`accelerate launch --aws_access_key_id XXX --aws_secret_access_key YYY`" + ) + aws_access_key_id = _ask_field("AWS Access Key ID: ") + os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id + + aws_secret_access_key = _ask_field("AWS Secret Access Key: ") + os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key + + aws_region = _ask_field("Enter your AWS Region: [us-east-1]", default="us-east-1") + os.environ["AWS_DEFAULT_REGION"] = aws_region + + role_management = _ask_options( + "Do you already have an IAM Role for executing Amazon SageMaker Training Jobs?", + ["Provide IAM Role name", "Create new IAM role using credentials"], + int, + ) + if role_management == 0: + iam_role_name = _ask_field("Enter your IAM role name: ") + else: + iam_role_name = "accelerate_sagemaker_execution_role" + print(f'Accelerate will create an iam role "{iam_role_name}" using the provided credentials') + _create_iam_role_for_sagemaker(iam_role_name) + + is_custom_docker_image = _ask_field( + "Do you want to use custom Docker image? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + docker_image = None + if is_custom_docker_image: + docker_image = _ask_field("Enter your Docker image: ", lambda x: str(x).lower()) + + is_sagemaker_inputs_enabled = _ask_field( + "Do you want to provide SageMaker input channels with data locations? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + sagemaker_inputs_file = None + if is_sagemaker_inputs_enabled: + sagemaker_inputs_file = _ask_field( + "Enter the path to the SageMaker inputs TSV file with columns (channel_name, data_location): ", + lambda x: str(x).lower(), + ) + + is_sagemaker_metrics_enabled = _ask_field( + "Do you want to enable SageMaker metrics? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + sagemaker_metrics_file = None + if is_sagemaker_metrics_enabled: + sagemaker_metrics_file = _ask_field( + "Enter the path to the SageMaker metrics TSV file with columns (metric_name, metric_regex): ", + lambda x: str(x).lower(), + ) + + distributed_type = _ask_options( + "What is the distributed mode?", + ["No distributed training", "Data parallelism"], + _convert_sagemaker_distributed_mode, + ) + dynamo_config = {} + use_dynamo = _ask_field( + "Do you wish to optimize your script with torch dynamo?[yes/NO]:", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if use_dynamo: + prefix = "dynamo_" + dynamo_config[prefix + "backend"] = _ask_options( + "Which dynamo backend would you like to use?", + [x.lower() for x in DYNAMO_BACKENDS], + _convert_dynamo_backend, + default=2, + ) + use_custom_options = _ask_field( + "Do you want to customize the defaults sent to torch.compile? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + + if use_custom_options: + dynamo_config[prefix + "mode"] = _ask_options( + "Which mode do you want to use?", + TORCH_DYNAMO_MODES, + lambda x: TORCH_DYNAMO_MODES[int(x)], + default="default", + ) + dynamo_config[prefix + "use_fullgraph"] = _ask_field( + "Do you want the fullgraph mode or it is ok to break model into several subgraphs? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + dynamo_config[prefix + "use_dynamic"] = _ask_field( + "Do you want to enable dynamic shape tracing? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + dynamo_config[prefix + "use_regional_compilation"] = _ask_field( + "Do you want to enable regional compilation? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + + ec2_instance_query = "Which EC2 instance type you want to use for your training?" + if distributed_type != SageMakerDistributedType.NO: + ec2_instance_type = _ask_options( + ec2_instance_query, SAGEMAKER_PARALLEL_EC2_INSTANCES, lambda x: SAGEMAKER_PARALLEL_EC2_INSTANCES[int(x)] + ) + else: + ec2_instance_query += "? [ml.p3.2xlarge]:" + ec2_instance_type = _ask_field(ec2_instance_query, lambda x: str(x).lower(), default="ml.p3.2xlarge") + + debug = False + if distributed_type != SageMakerDistributedType.NO: + debug = _ask_field( + "Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + + num_machines = 1 + if distributed_type in (SageMakerDistributedType.DATA_PARALLEL, SageMakerDistributedType.MODEL_PARALLEL): + num_machines = _ask_field( + "How many machines do you want use? [1]: ", + int, + default=1, + ) + + mixed_precision = _ask_options( + "Do you wish to use FP16 or BF16 (mixed precision)?", + ["no", "fp16", "bf16", "fp8"], + _convert_mixed_precision, + ) + + if use_dynamo and mixed_precision == "no": + print( + "Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts." + ) + + return SageMakerConfig( + image_uri=docker_image, + compute_environment=ComputeEnvironment.AMAZON_SAGEMAKER, + distributed_type=distributed_type, + use_cpu=False, + dynamo_config=dynamo_config, + ec2_instance_type=ec2_instance_type, + profile=aws_profile, + region=aws_region, + iam_role_name=iam_role_name, + mixed_precision=mixed_precision, + num_machines=num_machines, + sagemaker_inputs_file=sagemaker_inputs_file, + sagemaker_metrics_file=sagemaker_metrics_file, + debug=debug, + ) diff --git a/accelerate/commands/config/update.py b/accelerate/commands/config/update.py new file mode 100644 index 0000000000000000000000000000000000000000..369bb638e60516cb1ef46dbe53497c915834cbe4 --- /dev/null +++ b/accelerate/commands/config/update.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +from .config_args import default_config_file, load_config_from_file +from .config_utils import SubcommandHelpFormatter + + +description = "Update an existing config file with the latest defaults while maintaining the old configuration." + + +def update_config(args): + """ + Update an existing config file with the latest defaults while maintaining the old configuration. + """ + config_file = args.config_file + if config_file is None and Path(default_config_file).exists(): + config_file = default_config_file + elif not Path(config_file).exists(): + raise ValueError(f"The passed config file located at {config_file} doesn't exist.") + config = load_config_from_file(config_file) + + if config_file.endswith(".json"): + config.to_json_file(config_file) + else: + config.to_yaml_file(config_file) + return config_file + + +def update_command_parser(parser, parents): + parser = parser.add_parser("update", parents=parents, help=description, formatter_class=SubcommandHelpFormatter) + parser.add_argument( + "--config_file", + default=None, + help=( + "The path to the config file to update. Will default to a file named default_config.yaml in the cache " + "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have " + "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed " + "with 'huggingface'." + ), + ) + + parser.set_defaults(func=update_config_command) + return parser + + +def update_config_command(args): + config_file = update_config(args) + print(f"Successfully updated the configuration file at {config_file}.") diff --git a/accelerate/commands/env.py b/accelerate/commands/env.py new file mode 100644 index 0000000000000000000000000000000000000000..22790c40bfb7c489dda7e5e5aa172998f80fd45b --- /dev/null +++ b/accelerate/commands/env.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import platform +import subprocess + +import numpy as np +import psutil +import torch + +from accelerate import __version__ as version +from accelerate.commands.config import default_config_file, load_config_from_file + +from ..utils import ( + is_mlu_available, + is_musa_available, + is_neuron_available, + is_npu_available, + is_sdaa_available, + is_xpu_available, +) + + +def env_command_parser(subparsers=None): + if subparsers is not None: + parser = subparsers.add_parser("env") + else: + parser = argparse.ArgumentParser("Accelerate env command") + + parser.add_argument( + "--config_file", default=None, help="The config file to use for the default values in the launching script." + ) + + if subparsers is not None: + parser.set_defaults(func=env_command) + return parser + + +def env_command(args): + pt_version = torch.__version__ + pt_cuda_available = torch.cuda.is_available() + pt_xpu_available = is_xpu_available() + pt_mlu_available = is_mlu_available() + pt_sdaa_available = is_sdaa_available() + pt_musa_available = is_musa_available() + pt_npu_available = is_npu_available() + pt_neuron_available = is_neuron_available() + + accelerator = "N/A" + if pt_cuda_available: + accelerator = "CUDA" + elif pt_xpu_available: + accelerator = "XPU" + elif pt_mlu_available: + accelerator = "MLU" + elif pt_sdaa_available: + accelerator = "SDAA" + elif pt_musa_available: + accelerator = "MUSA" + elif pt_npu_available: + accelerator = "NPU" + elif pt_neuron_available: + accelerator = "NEURON" + + accelerate_config = "Not found" + # Get the default from the config file. + if args.config_file is not None or os.path.isfile(default_config_file): + accelerate_config = load_config_from_file(args.config_file).to_dict() + + # if we can run which, get it + command = None + bash_location = "Not found" + if os.name == "nt": + command = ["where", "accelerate"] + elif os.name == "posix": + command = ["which", "accelerate"] + if command is not None: + bash_location = subprocess.check_output(command, text=True, stderr=subprocess.STDOUT).strip() + info = { + "`Accelerate` version": version, + "Platform": platform.platform(), + "`accelerate` bash location": bash_location, + "Python version": platform.python_version(), + "Numpy version": np.__version__, + "PyTorch version": f"{pt_version}", + "PyTorch accelerator": accelerator, + "System RAM": f"{psutil.virtual_memory().total / 1024**3:.2f} GB", + } + if pt_cuda_available: + info["GPU type"] = torch.cuda.get_device_name() + elif pt_xpu_available: + info["XPU type"] = torch.xpu.get_device_name() + elif pt_mlu_available: + info["MLU type"] = torch.mlu.get_device_name() + elif pt_sdaa_available: + info["SDAA type"] = torch.sdaa.get_device_name() + elif pt_musa_available: + info["MUSA type"] = torch.musa.get_device_name() + elif pt_neuron_available: + info["NEURON type"] = torch.neuron.get_device_name() + elif pt_npu_available: + info["CANN version"] = torch.version.cann + + print("\nCopy-and-paste the text below in your GitHub issue\n") + print("\n".join([f"- {prop}: {val}" for prop, val in info.items()])) + + print("- `Accelerate` default config:" if args.config_file is None else "- `Accelerate` config passed:") + accelerate_config_str = ( + "\n".join([f"\t- {prop}: {val}" for prop, val in accelerate_config.items()]) + if isinstance(accelerate_config, dict) + else f"\t{accelerate_config}" + ) + print(accelerate_config_str) + + info["`Accelerate` configs"] = accelerate_config + + return info + + +def main() -> int: + parser = env_command_parser() + args = parser.parse_args() + env_command(args) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/accelerate/commands/estimate.py b/accelerate/commands/estimate.py new file mode 100644 index 0000000000000000000000000000000000000000..c1a8b41cead9d0d309bf482bde5a8f558b18b6fd --- /dev/null +++ b/accelerate/commands/estimate.py @@ -0,0 +1,318 @@ +#!/usr/bin/env python + +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch +from huggingface_hub import model_info +from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError + +from accelerate import init_empty_weights +from accelerate.commands.utils import CustomArgumentParser +from accelerate.utils import ( + calculate_maximum_sizes, + convert_bytes, + is_timm_available, + is_transformers_available, +) + + +if is_transformers_available(): + import transformers + from transformers import AutoConfig, AutoModel + +if is_timm_available(): + import timm + + +def verify_on_hub(repo: str, token: Optional[str] = None): + "Verifies that the model is on the hub and returns the model info." + try: + return model_info(repo, token=token) + except (OSError, GatedRepoError): + return "gated" + except RepositoryNotFoundError: + return "repo" + + +def check_has_model(error): + """ + Checks what library spawned `error` when a model is not found + """ + if is_timm_available() and isinstance(error, RuntimeError) and "Unknown model" in error.args[0]: + return "timm" + elif ( + is_transformers_available() + and isinstance(error, OSError) + and "does not appear to have a file named" in error.args[0] + ): + return "transformers" + else: + return "unknown" + + +def create_empty_model( + model_name: str, library_name: str, trust_remote_code: bool = False, access_token: Optional[str] = None +): + """ + Creates an empty model in full precision from its parent library on the `Hub` to calculate the overall memory + consumption. + + Args: + model_name (`str`): + The model name on the Hub + library_name (`str`): + The library the model has an integration with, such as `transformers`. Will be used if `model_name` has no + metadata on the Hub to determine the library. + trust_remote_code (`bool`, `optional`, defaults to `False`): + Whether or not to allow for custom models defined on the Hub in their own modeling files. This option + should only be set to `True` for repositories you trust and in which you have read the code, as it will + execute code present on the Hub on your local machine. + access_token (`str`, `optional`, defaults to `None`): + The access token to use to access private or gated models on the Hub. (for use on the Gradio app) + + Returns: + `torch.nn.Module`: The torch model that has been initialized on the `meta` device. + + """ + model_info = verify_on_hub(model_name, access_token) + # Simplified errors + if model_info == "gated": + raise OSError( + f"Repo for model `{model_name}` is gated. You must be authenticated to access it. Please run `huggingface-cli login`." + ) + elif model_info == "repo": + raise OSError( + f"Repo for model `{model_name}` does not exist on the Hub. If you are trying to access a private repo," + " make sure you are authenticated via `huggingface-cli login` and have access." + ) + if library_name is None: + library_name = getattr(model_info, "library_name", False) + if not library_name: + raise ValueError( + f"Model `{model_name}` does not have any library metadata on the Hub, please manually pass in a `--library_name` to use (such as `transformers`)" + ) + if library_name == "transformers": + if not is_transformers_available(): + raise ImportError( + f"To check `{model_name}`, `transformers` must be installed. Please install it via `pip install transformers`" + ) + print(f"Loading pretrained config for `{model_name}` from `transformers`...") + if model_info.config is None: + raise RuntimeError(f"Tried to load `{model_name}` with `transformers` but it does not have any metadata.") + + auto_map = model_info.config.get("auto_map", False) + config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code, token=access_token) + with init_empty_weights(): + # remote code could specify a specific `AutoModel` class in the `auto_map` + constructor = AutoModel + if isinstance(auto_map, dict): + value = None + for key in auto_map.keys(): + if key.startswith("AutoModelFor"): + value = key + break + if value is not None: + constructor = getattr(transformers, value) + # we need to pass the dtype, otherwise it is going to use the torch_dtype that is saved in the config + model = constructor.from_config(config, torch_dtype=torch.float32, trust_remote_code=trust_remote_code) + elif library_name == "timm": + if not is_timm_available(): + raise ImportError( + f"To check `{model_name}`, `timm` must be installed. Please install it via `pip install timm`" + ) + print(f"Loading pretrained config for `{model_name}` from `timm`...") + with init_empty_weights(): + model = timm.create_model(model_name, pretrained=False) + else: + raise ValueError( + f"Library `{library_name}` is not supported yet, please open an issue on GitHub for us to add support." + ) + return model + + +def create_ascii_table(headers: list, rows: list, title: str): + "Creates a pretty table from a list of rows, minimal version of `tabulate`." + sep_char, in_between = "│", "─" + column_widths = [] + for i in range(len(headers)): + column_values = [row[i] for row in rows] + [headers[i]] + max_column_width = max(len(value) for value in column_values) + column_widths.append(max_column_width) + + formats = [f"%{column_widths[i]}s" for i in range(len(rows[0]))] + + pattern = f"{sep_char}{sep_char.join(formats)}{sep_char}" + diff = 0 + + def make_row(left_char, middle_char, right_char): + return f"{left_char}{middle_char.join([in_between * n for n in column_widths])}{in_between * diff}{right_char}" + + separator = make_row("├", "┼", "┤") + if len(title) > sum(column_widths): + diff = abs(len(title) - len(separator)) + column_widths[-1] += diff + + # Update with diff + separator = make_row("├", "┼", "┤") + initial_rows = [ + make_row("┌", in_between, "┐"), + f"{sep_char}{title.center(len(separator) - 2)}{sep_char}", + make_row("├", "┬", "┤"), + ] + table = "\n".join(initial_rows) + "\n" + column_widths[-1] += diff + centered_line = [text.center(column_widths[i]) for i, text in enumerate(headers)] + table += f"{pattern % tuple(centered_line)}\n{separator}\n" + for i, line in enumerate(rows): + centered_line = [t.center(column_widths[i]) for i, t in enumerate(line)] + table += f"{pattern % tuple(centered_line)}\n" + table += f"└{'┴'.join([in_between * n for n in column_widths])}┘" + + return table + + +def estimate_command_parser(subparsers=None): + if subparsers is not None: + parser = subparsers.add_parser("estimate-memory") + else: + parser = CustomArgumentParser( + description="Model size estimator for fitting a model onto device(e.g. cuda, xpu) memory." + ) + + parser.add_argument("model_name", type=str, help="The model name on the Hugging Face Hub.") + parser.add_argument( + "--library_name", + type=str, + help="The library the model has an integration with, such as `transformers`, needed only if this information is not stored on the Hub.", + choices=["timm", "transformers"], + ) + parser.add_argument( + "--dtypes", + type=str, + nargs="+", + default=["float32", "float16", "int8", "int4"], + help="The dtypes to use for the model, must be one (or many) of `float32`, `float16`, `int8`, and `int4`", + choices=["float32", "float16", "int8", "int4"], + ) + parser.add_argument( + "--trust_remote_code", + action="store_true", + help="""Whether or not to allow for custom models defined on the Hub in their own modeling files. This flag + should only be used for repositories you trust and in which you have read the code, as it will execute + code present on the Hub on your local machine.""", + default=False, + ) + + if subparsers is not None: + parser.set_defaults(func=estimate_command) + return parser + + +def estimate_training_usage(bytes: int, mixed_precision: str, msamp_config: Optional[str] = None) -> dict: + """ + Given an amount of `bytes` and `mixed_precision`, calculates how much training memory is needed for a batch size of + 1. + + Args: + bytes (`int`): + The size of the model being trained. + mixed_precision (`str`): + The mixed precision that would be ran. + msamp_config (`str`): + The msamp config to estimate the training memory for if `mixed_precision` is set to `"fp8"`. + """ + memory_sizes = {"model": -1, "optimizer": -1, "gradients": -1, "step": -1} + fp32_size = bytes + fp16_size = bytes // 2 + + if mixed_precision == "float32": + memory_sizes["model"] = fp32_size + memory_sizes["gradients"] = fp32_size + memory_sizes["optimizer"] = fp32_size * 2 + memory_sizes["step"] = fp32_size * 4 + elif mixed_precision in ("float16", "bfloat16") or (mixed_precision == "fp8" and msamp_config is None): + # With native `TransformersEngine`, there is no memory savings with FP8 + # With mixed precision training, the model has weights stored + # in FP16 and FP32 + memory_sizes["model"] = fp32_size + # 1.5 from weight gradient + computation (GEMM) + memory_sizes["gradients"] = fp32_size + fp16_size + # 2x from optimizer states + memory_sizes["optimizer"] = fp32_size * 2 # Optimizer states + memory_sizes["step"] = memory_sizes["optimizer"] + return memory_sizes + + +def gather_data(args): + "Creates an empty model and gathers the data for the sizes" + try: + model = create_empty_model( + args.model_name, library_name=args.library_name, trust_remote_code=args.trust_remote_code + ) + except (RuntimeError, OSError) as e: + library = check_has_model(e) + if library != "unknown": + raise RuntimeError( + f"Tried to load `{args.model_name}` with `{library}` but a possible model to load was not found inside the repo." + ) + raise e + + total_size, largest_layer = calculate_maximum_sizes(model) + + data = [] + + for dtype in args.dtypes: + dtype_total_size = total_size + dtype_largest_layer = largest_layer[0] + dtype_training_size = estimate_training_usage(dtype_total_size, dtype) + if dtype == "float16": + dtype_total_size /= 2 + dtype_largest_layer /= 2 + elif dtype == "int8": + dtype_total_size /= 4 + dtype_largest_layer /= 4 + elif dtype == "int4": + dtype_total_size /= 8 + dtype_largest_layer /= 8 + data.append([dtype, dtype_largest_layer, dtype_total_size, dtype_training_size]) + return data + + +def estimate_command(args): + data = gather_data(args) + for row in data: + for i, item in enumerate(row): + if isinstance(item, (int, float)): + row[i] = convert_bytes(item) + elif isinstance(item, dict): + training_usage = max(item.values()) + row[i] = convert_bytes(training_usage) if training_usage != -1 else "N/A" + + headers = ["dtype", "Largest Layer", "Total Size", "Training using Adam"] + + title = f"Memory Usage for loading `{args.model_name}`" + table = create_ascii_table(headers, data, title) + print(table) + + +def main(): + parser = estimate_command_parser() + args = parser.parse_args() + estimate_command(args) + + +if __name__ == "__main__": + main() diff --git a/accelerate/commands/launch.py b/accelerate/commands/launch.py new file mode 100644 index 0000000000000000000000000000000000000000..3b7becf1b0c7f864b04e626b5313bae7647774cf --- /dev/null +++ b/accelerate/commands/launch.py @@ -0,0 +1,1415 @@ +#!/usr/bin/env python + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import importlib +import logging +import os +import subprocess +import sys +from pathlib import Path + +import torch + +from accelerate.commands.config import default_config_file, load_config_from_file +from accelerate.commands.config.config_args import SageMakerConfig +from accelerate.commands.config.config_utils import DYNAMO_BACKENDS +from accelerate.commands.utils import CustomArgumentParser +from accelerate.state import get_int_from_env +from accelerate.utils import ( + ComputeEnvironment, + DistributedType, + PrepareForLaunch, + _filter_args, + check_cuda_p2p_ib_support, + convert_dict_to_env_variables, + is_bf16_available, + is_deepspeed_available, + is_hpu_available, + is_mlu_available, + is_musa_available, + is_neuron_available, + is_npu_available, + is_rich_available, + is_sagemaker_available, + is_sdaa_available, + is_torch_xla_available, + is_xpu_available, + patch_environment, + prepare_deepspeed_cmd_env, + prepare_multi_gpu_env, + prepare_sagemager_args_inputs, + prepare_simple_launcher_cmd_env, + prepare_tpu, + str_to_bool, +) +from accelerate.utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS, TORCH_DYNAMO_MODES + + +if is_rich_available(): + from rich import get_console + from rich.logging import RichHandler + + FORMAT = "%(message)s" + logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()]) + + +logger = logging.getLogger(__name__) + + +options_to_group = { + "multi_gpu": "Distributed GPUs", + "tpu": "TPU", + "use_deepspeed": "DeepSpeed Arguments", + "use_fsdp": "FSDP Arguments", + "use_megatron_lm": "Megatron-LM Arguments", + "fp8_backend": "FP8 Arguments", +} + + +def clean_option(option): + "Finds all cases of - after the first two characters and changes them to _" + if "fp8_backend" in option: + option = "--fp8_backend" + if option.startswith("--"): + return option[2:].replace("-", "_") + + +class CustomHelpFormatter(argparse.HelpFormatter): + """ + This is a custom help formatter that will hide all arguments that are not used in the command line when the help is + called. This is useful for the case where the user is using a specific platform and only wants to see the arguments + for that platform. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.titles = [ + "Hardware Selection Arguments", + "Resource Selection Arguments", + "Training Paradigm Arguments", + "positional arguments", + "optional arguments", + ] + + def add_argument(self, action: argparse.Action): + if "accelerate" in sys.argv[0] and "launch" in sys.argv[1:]: + args = sys.argv[2:] + else: + args = sys.argv[1:] + + if len(args) > 1: + args = list(map(clean_option, args)) + used_platforms = [arg for arg in args if arg in options_to_group.keys()] + used_titles = [options_to_group[o] for o in used_platforms] + if action.container.title not in self.titles + used_titles: + action.help = argparse.SUPPRESS + elif action.container.title == "Hardware Selection Arguments": + if set(action.option_strings).isdisjoint(set(args)): + action.help = argparse.SUPPRESS + else: + action.help = action.help + " (currently selected)" + elif action.container.title == "Training Paradigm Arguments": + if set(action.option_strings).isdisjoint(set(args)): + action.help = argparse.SUPPRESS + else: + action.help = action.help + " (currently selected)" + + action.option_strings = [s for s in action.option_strings if "-" not in s[2:]] + super().add_argument(action) + + def end_section(self): + if len(self._current_section.items) < 2: + self._current_section.items = [] + self._current_section.heading = "" + super().end_section() + + +def launch_command_parser(subparsers=None): + description = "Launch a python script in a distributed scenario. Arguments can be passed in with either hyphens (`--num-processes=2`) or underscores (`--num_processes=2`)" + if subparsers is not None: + parser = subparsers.add_parser( + "launch", description=description, add_help=False, allow_abbrev=False, formatter_class=CustomHelpFormatter + ) + else: + parser = CustomArgumentParser( + "Accelerate launch command", + description=description, + add_help=False, + allow_abbrev=False, + formatter_class=CustomHelpFormatter, + ) + + parser.add_argument("-h", "--help", action="help", help="Show this help message and exit.") + + parser.add_argument( + "--config_file", + default=None, + help="The config file to use for the default values in the launching script.", + ) + parser.add_argument( + "--quiet", + "-q", + action="store_true", + help="Silence subprocess errors from the launch stack trace and only show the relevant tracebacks. (Only applicable to DeepSpeed and single-process configurations)", + ) + # Hardware selection arguments + hardware_args = parser.add_argument_group( + "Hardware Selection Arguments", "Arguments for selecting the hardware to be used." + ) + hardware_args.add_argument( + "--cpu", default=False, action="store_true", help="Whether or not to force the training on the CPU." + ) + hardware_args.add_argument( + "--multi_gpu", + default=False, + action="store_true", + help="Whether or not this should launch a distributed GPU training.", + ) + hardware_args.add_argument( + "--tpu", default=False, action="store_true", help="Whether or not this should launch a TPU training." + ) + # Resource selection arguments + resource_args = parser.add_argument_group( + "Resource Selection Arguments", "Arguments for fine-tuning how available hardware should be used." + ) + resource_args.add_argument( + "--mixed_precision", + type=str, + choices=["no", "fp16", "bf16", "fp8"], + help="Whether or not to use mixed precision training. " + "Choose between FP16 and BF16 (bfloat16) training. " + "BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later.", + ) + resource_args.add_argument( + "--num_processes", type=int, default=None, help="The total number of processes to be launched in parallel." + ) + resource_args.add_argument( + "--num_machines", type=int, default=None, help="The total number of machines used in this training." + ) + resource_args.add_argument( + "--num_cpu_threads_per_process", + type=int, + default=None, + help="The number of CPU threads per process. Can be tuned for optimal performance.", + ) + resource_args.add_argument( + "--enable_cpu_affinity", + default=False, + action="store_true", + help="Whether or not CPU affinity and balancing should be enabled. Currently only supported on NVIDIA hardware.", + ) + # Dynamo arguments + resource_args.add_argument( + "--dynamo_backend", + type=str, + choices=["no"] + [b.lower() for b in DYNAMO_BACKENDS], + help="Choose a backend to optimize your training with dynamo, see more at " + "https://github.com/pytorch/torchdynamo.", + ) + resource_args.add_argument( + "--dynamo_mode", + type=str, + default="default", + choices=TORCH_DYNAMO_MODES, + help="Choose a mode to optimize your training with dynamo.", + ) + resource_args.add_argument( + "--dynamo_use_fullgraph", + default=False, + action="store_true", + help="Whether to use full graph mode for dynamo or it is ok to break model into several subgraphs", + ) + resource_args.add_argument( + "--dynamo_use_dynamic", + default=False, + action="store_true", + help="Whether to enable dynamic shape tracing.", + ) + resource_args.add_argument( + "--dynamo_use_regional_compilation", + default=False, + action="store_true", + help="Whether to enable regional compilation.", + ) + + # Training Paradigm arguments + paradigm_args = parser.add_argument_group( + "Training Paradigm Arguments", "Arguments for selecting which training paradigm to be used." + ) + paradigm_args.add_argument( + "--use_deepspeed", + default=False, + action="store_true", + help="Whether to use deepspeed.", + ) + paradigm_args.add_argument( + "--use_fsdp", + default=False, + action="store_true", + help="Whether to use fsdp.", + ) + paradigm_args.add_argument( + "--use_parallelism_config", + default=False, + action="store_true", + help="Whether to use the parallelism config to configure the N-d distributed training.", + ) + paradigm_args.add_argument( + "--use_megatron_lm", + default=False, + action="store_true", + help="Whether to use Megatron-LM.", + ) + + # distributed GPU training arguments + distributed_args = parser.add_argument_group("Distributed GPUs", "Arguments related to distributed GPU training.") + distributed_args.add_argument( + "--gpu_ids", + default=None, + help="What GPUs (by id) should be used for training on this machine as a comma-separated list", + ) + distributed_args.add_argument( + "--same_network", + default=False, + action="store_true", + help="Whether all machines used for multinode training exist on the same local network.", + ) + distributed_args.add_argument( + "--machine_rank", type=int, default=None, help="The rank of the machine on which this script is launched." + ) + distributed_args.add_argument( + "--main_process_ip", type=str, default=None, help="The IP address of the machine of rank 0." + ) + distributed_args.add_argument( + "--main_process_port", + type=int, + default=None, + help="The port to use to communicate with the machine of rank 0.", + ) + distributed_args.add_argument( + "-t", + "--tee", + default="0", + type=str, + help="Tee std streams into a log file and also to console.", + ) + distributed_args.add_argument( + "--log_dir", + type=str, + default=None, + help=( + "Base directory to use for log files when using torchrun/torch.distributed.run as launcher. " + "Use with --tee to redirect std streams info log files." + ), + ) + distributed_args.add_argument( + "--role", + type=str, + default="default", + help="User-defined role for the workers.", + ) + # Rendezvous related arguments + distributed_args.add_argument( + "--rdzv_backend", + type=str, + default="static", + help="The rendezvous method to use, such as 'static' (the default) or 'c10d'", + ) + distributed_args.add_argument( + "--rdzv_conf", + type=str, + default="", + help="Additional rendezvous configuration (=,=,...).", + ) + distributed_args.add_argument( + "--max_restarts", + type=int, + default=0, + help="Maximum number of worker group restarts before failing.", + ) + distributed_args.add_argument( + "--monitor_interval", + type=float, + default=0.1, + help="Interval, in seconds, to monitor the state of workers.", + ) + parser.add_argument( + "-m", + "--module", + action="store_true", + help="Change each process to interpret the launch script as a Python module, executing with the same behavior as 'python -m'.", + ) + parser.add_argument( + "--no_python", + action="store_true", + help="Skip prepending the training script with 'python' - just execute it directly. Useful when the script is not a Python script.", + ) + + # TPU arguments + tpu_args = parser.add_argument_group("TPU", "Arguments related to TPU.") + tpu_args.add_argument( + "--tpu_cluster", + action="store_true", + dest="tpu_use_cluster", + help="Whether to use a GCP TPU pod for training.", + ) + tpu_args.add_argument( + "--no_tpu_cluster", + action="store_false", + dest="tpu_use_cluster", + help="Should not be passed explicitly, this is for internal use only.", + ) + tpu_args.add_argument( + "--tpu_use_sudo", + action="store_true", + help="Whether to use `sudo` when running the TPU training script in each pod.", + ) + tpu_args.add_argument( + "--vm", + type=str, + action="append", + help=( + "List of single Compute VM instance names. " + "If not provided we assume usage of instance groups. For TPU pods." + ), + ) + tpu_args.add_argument( + "--env", + type=str, + action="append", + help="List of environment variables to set on the Compute VM instances. For TPU pods.", + ) + tpu_args.add_argument( + "--main_training_function", + type=str, + default=None, + help="The name of the main function to be executed in your script (only for TPU training).", + ) + tpu_args.add_argument( + "--downcast_bf16", + action="store_true", + help="Whether when using bf16 precision on TPUs if both float and double tensors are cast to bfloat16 or if double tensors remain as float32.", + ) + + # DeepSpeed arguments + deepspeed_args = parser.add_argument_group("DeepSpeed Arguments", "Arguments related to DeepSpeed.") + deepspeed_args.add_argument( + "--deepspeed_config_file", + default=None, + type=str, + help="DeepSpeed config file.", + ) + deepspeed_args.add_argument( + "--zero_stage", + default=None, + type=int, + help="DeepSpeed's ZeRO optimization stage (useful only when `use_deepspeed` flag is passed). " + "If unspecified, will default to `2`.", + ) + deepspeed_args.add_argument( + "--offload_optimizer_device", + default=None, + type=str, + help="Decides where (none|cpu|nvme) to offload optimizer states (useful only when `use_deepspeed` flag is passed). " + "If unspecified, will default to 'none'.", + ) + deepspeed_args.add_argument( + "--offload_param_device", + default=None, + type=str, + help="Decides where (none|cpu|nvme) to offload parameters (useful only when `use_deepspeed` flag is passed). " + "If unspecified, will default to 'none'.", + ) + deepspeed_args.add_argument( + "--offload_optimizer_nvme_path", + default=None, + type=str, + help="Decides Nvme Path to offload optimizer states (useful only when `use_deepspeed` flag is passed). " + "If unspecified, will default to 'none'.", + ) + deepspeed_args.add_argument( + "--offload_param_nvme_path", + default=None, + type=str, + help="Decides Nvme Path to offload parameters (useful only when `use_deepspeed` flag is passed). " + "If unspecified, will default to 'none'.", + ) + deepspeed_args.add_argument( + "--gradient_accumulation_steps", + default=None, + type=int, + help="No of gradient_accumulation_steps used in your training script (useful only when `use_deepspeed` flag is passed). " + "If unspecified, will default to `1`.", + ) + deepspeed_args.add_argument( + "--gradient_clipping", + default=None, + type=float, + help="gradient clipping value used in your training script (useful only when `use_deepspeed` flag is passed). " + "If unspecified, will default to `1.0`.", + ) + deepspeed_args.add_argument( + "--zero3_init_flag", + default=None, + type=str, + help="Decides Whether (true|false) to enable `deepspeed.zero.Init` for constructing massive models. " + "Only applicable with DeepSpeed ZeRO Stage-3. If unspecified, will default to `true`.", + ) + deepspeed_args.add_argument( + "--zero3_save_16bit_model", + default=None, + type=str, + help="Decides Whether (true|false) to save 16-bit model weights when using ZeRO Stage-3. " + "Only applicable with DeepSpeed ZeRO Stage-3. If unspecified, will default to `false`.", + ) + deepspeed_args.add_argument( + "--deepspeed_hostfile", + default=None, + type=str, + help="DeepSpeed hostfile for configuring multi-node compute resources.", + ) + deepspeed_args.add_argument( + "--deepspeed_exclusion_filter", + default=None, + type=str, + help="DeepSpeed exclusion filter string when using multi-node setup.", + ) + deepspeed_args.add_argument( + "--deepspeed_inclusion_filter", + default=None, + type=str, + help="DeepSpeed inclusion filter string when using multi-node setup.", + ) + deepspeed_args.add_argument( + "--deepspeed_multinode_launcher", + default=None, + type=str, + help="DeepSpeed multi-node launcher to use, e.g. `pdsh`, `standard`, `openmpi`, `mvapich`, `mpich`, `slurm`, `nossh` (requires DeepSpeed >= 0.14.5). If unspecified, will default to `pdsh`.", + ) + deepspeed_args.add_argument( + "--deepspeed_moe_layer_cls_names", + default=None, + type=str, + help="comma-separated list of transformer MoE layer class names (case-sensitive) to wrap ,e.g, `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock` ..." + " (useful only when `use_deepspeed` flag is passed).", + ) + + # fsdp arguments + fsdp_args = parser.add_argument_group("FSDP Arguments", "Arguments related to Fully Shared Data Parallelism.") + fsdp_args.add_argument( + "--fsdp_version", + type=str, + default="1", + choices=["1", "2"], + help="FSDP version to use. (useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_offload_params", + default="false", + type=str, + help="Decides Whether (true|false) to offload parameters and gradients to CPU. (useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_min_num_params", + type=int, + default=int(1e8), + help="FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `use_fsdp` flag is passed).", + ) + # We enable this for backwards compatibility, throw a warning if this is set in `FullyShardedDataParallelPlugin` + fsdp_args.add_argument( + "--fsdp_sharding_strategy", + type=str, + default="FULL_SHARD", + help="FSDP's sharding strategy. (useful only when `use_fsdp` flag is passed and `fsdp_version=1`).", + ) + fsdp_args.add_argument( + "--fsdp_reshard_after_forward", + type=str, + default="true", + help="FSDP's Reshard After Forward Strategy. (useful only when `use_fsdp` flag is passed). Supports either boolean (FSDP2) or `FULL_SHARD | SHARD_GRAD_OP | NO_RESHARD` (FSDP1).", + ) + fsdp_args.add_argument( + "--fsdp_auto_wrap_policy", + type=str, + default=None, + help="FSDP's auto wrap policy. (useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_transformer_layer_cls_to_wrap", + default=None, + type=str, + help="Transformer layer class name (case-sensitive) to wrap ,e.g, `BertLayer`, `GPTJBlock`, `T5Block` .... " + "(useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_backward_prefetch", + default=None, + type=str, + help="FSDP's backward prefetch policy. (useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_state_dict_type", + default=None, + type=str, + help="FSDP's state dict type. (useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_forward_prefetch", + default="false", + type=str, + help="If True, then FSDP explicitly prefetches the next upcoming " + "all-gather while executing in the forward pass (useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_use_orig_params", + default="true", + type=str, + help="If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable parameters." + " (useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_cpu_ram_efficient_loading", + default="true", + type=str, + help="If True, only the first process loads the pretrained model checkoint while all other processes have empty weights. " + "Only applicable for 🤗 Transformers. When using this, `--fsdp_sync_module_states` needs to True. " + "(useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_sync_module_states", + default="true", + type=str, + help="If True, each individually wrapped FSDP unit will broadcast module parameters from rank 0." + " (useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_activation_checkpointing", + default="false", + type=str, + help="Decides Whether (true|false) intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder. (useful only when `use_fsdp` flag is passed).", + ) + + # megatron_lm args + megatron_lm_args = parser.add_argument_group("Megatron-LM Arguments", "Arguments related to Megatron-LM.") + megatron_lm_args.add_argument( + "--megatron_lm_tp_degree", + type=int, + default=1, + help="Megatron-LM's Tensor Parallelism (TP) degree. (useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_use_custom_fsdp", + type=bool, + default=False, + help="Whether to use custom FSDP. (useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_no_load_optim", + type=bool, + default=False, + help="Whether to not load optimizer. (useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_eod_mask_loss", + type=bool, + default=False, + help="Whether to use eod mask loss. (useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_overlap_cpu_optimizer_d2h_h2d", + type=bool, + default=False, + help="Whether to overlap CPU optimizer step, gradients D2H and updated parameters H2D. (useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_no_save_optim", + type=bool, + default=False, + help="Whether to not save optimizer. (useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_optimizer_cpu_offload", + type=bool, + default=False, + help="Whether to use CPU offload for optimizer. (useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_use_precision_aware_optimizer", + type=bool, + default=False, + help="Whether to use precision aware optimizer. (useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_decoder_last_pipeline_num_layers", + type=int, + default=None, + help="Megatron-LM's decoder last pipeline number of layers, default None is even split of transformer layers across all pipeline stages.", + ) + megatron_lm_args.add_argument( + "--megatron_lm_pp_degree", + type=int, + default=1, + help="Megatron-LM's Pipeline Parallelism (PP) degree. (useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_num_micro_batches", + type=int, + default=None, + help="Megatron-LM's number of micro batches when PP degree > 1. (useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_sequence_parallelism", + default=None, + type=str, + help="Decides Whether (true|false) to enable Sequence Parallelism when TP degree > 1. " + "(useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_recompute_activations", + default=None, + type=str, + help="Decides Whether (true|false) to enable Selective Activation Recomputation. " + "(useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_use_distributed_optimizer", + default=None, + type=str, + help="Decides Whether (true|false) to use distributed optimizer " + "which shards optimizer state and gradients across Data Pralellel (DP) ranks. " + "(useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_gradient_clipping", + default=1.0, + type=float, + help="Megatron-LM's gradient clipping value based on global L2 Norm (0 to disable). " + "(useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_recompute_granularity", + default=None, + type=str, + help="Megatron-LM's recompute granularity (full, selective). " + "(useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_recompute_method", + default=None, + type=str, + help="Megatron-LM's recompute method (uniform, block). (useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_recompute_num_layers", + default=None, + type=int, + help="Megatron-LM's number of layers to recompute. (useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_attention_backend", + default=None, + type=str, + help="Decides Whether (true|false) to enable attention backend. " + "(useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_expert_model_parallel_size", + default=None, + type=int, + help="Megatron-LM's expert model parallel size. (useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_context_parallel_size", + default=None, + type=int, + help="Megatron-LM's context parallel size. (useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_attention_dropout", + default=None, + type=float, + help="Megatron-LM's attention dropout rate. (useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_hidden_dropout", + default=None, + type=float, + help="Megatron-LM's hidden dropout rate. (useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_attention_softmax_in_fp32", + default=None, + type=str, + help="Decides Whether (true|false) to use fp32 for attention softmax. " + "(useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_expert_tensor_parallel_size", + default=None, + type=int, + help="Megatron-LM's expert tensor parallel size. (useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_calculate_per_token_loss", + default=None, + type=str, + help="Decides Whether (true|false) to calculate per token loss. " + "(useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_use_rotary_position_embeddings", + default=None, + type=str, + help="Decides Whether (true|false) to use rotary position embeddings. " + "(useful only when `use_megatron_lm` flag is passed).", + ) + + # FP8 arguments + fp8_args = parser.add_argument_group( + "FP8 Arguments", "Arguments related to FP8 training (requires `--mixed_precision=fp8`)" + ) + fp8_args.add_argument( + "--fp8_backend", + type=str, + choices=["ao", "te", "msamp"], + help="Choose a backend to train with FP8 (ao: torchao, te: TransformerEngine, msamp: MS-AMP)", + ) + fp8_args.add_argument( + "--fp8_use_autocast_during_eval", + default=False, + action="store_true", + help="Whether to use FP8 autocast during eval mode (useful only when `--fp8_backend=te` is passed). Generally better metrics are found when this is not passed.", + ) + fp8_args.add_argument( + "--fp8_margin", + type=int, + default=0, + help="The margin to use for the gradient scaling (useful only when `--fp8_backend=te` is passed).", + ) + fp8_args.add_argument( + "--fp8_interval", + type=int, + default=1, + help="The interval to use for how often the scaling factor is recomputed (useful only when `--fp8_backend=te` is passed).", + ) + fp8_args.add_argument( + "--fp8_format", + type=str, + default="HYBRID", + choices=["HYBRID", "E4M3", "E5M2"], + help="The format to use for the FP8 recipe (useful only when `--fp8_backend=te` is passed).", + ) + fp8_args.add_argument( + "--fp8_amax_history_len", + type=int, + default=1024, + help="The length of the history to use for the scaling factor computation (useful only when `--fp8_backend=te` is passed).", + ) + fp8_args.add_argument( + "--fp8_amax_compute_algo", + type=str, + default="most_recent", + choices=["max", "most_recent"], + help="The algorithm to use for the scaling factor computation. (useful only when `--fp8_backend=te` is passed).", + ) + fp8_args.add_argument( + "--fp8_override_linear_precision", + type=lambda x: tuple(map(str_to_bool, x.split(","))), + default=(False, False, False), + help="Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision. Should be passed in a comma-separated string of booleans (useful only when `--fp8_backend=te` is passed).", + ) + fp8_args.add_argument( + "--fp8_opt_level", + type=str, + default="O2", + choices=["O1", "O2"], + help="What level of 8-bit collective communication should be used with MS-AMP (useful only when `--fp8_backend=msamp` is passed).", + ) + fp8_args.add_argument( + "--fp8_enable_fsdp_float8_all_gather", + default="true", + type=str_to_bool, + help="Whether to enable FSDP2 float8 all gather (useful only when `--fp8_backend=ao` is passed).", + ) + fp8_args.add_argument( + "--fp8_pad_inner_dim", + default="true", + type=str_to_bool, + help="Whether to pad the inner dimension for FP8 GEMMs (useful only when `--fp8_backend=ao` is passed).", + ) + + # AWS arguments + aws_args = parser.add_argument_group("AWS Arguments", "Arguments related to AWS.") + aws_args.add_argument( + "--aws_access_key_id", + type=str, + default=None, + help="The AWS_ACCESS_KEY_ID used to launch the Amazon SageMaker training job", + ) + aws_args.add_argument( + "--aws_secret_access_key", + type=str, + default=None, + help="The AWS_SECRET_ACCESS_KEY used to launch the Amazon SageMaker training job.", + ) + parser.add_argument( + "--debug", + action="store_true", + help="Whether to print out the torch.distributed stack trace when something fails.", + ) + parser.add_argument( + "training_script", + type=str, + help=( + "The full path to the script to be launched in parallel, followed by all the arguments for the training " + "script." + ), + ) + + # MPI arguments + mpirun_args = parser.add_argument_group("MPI Arguments", "Arguments related to mpirun for Multi-CPU") + mpirun_args.add_argument( + "--mpirun_hostfile", + type=str, + default=None, + help="Location for a hostfile for using Accelerate to launch a multi-CPU training job with mpirun. This will " + "get passed to the MPI --hostfile or -f parameter, depending on which MPI program is installed.", + ) + + # ParallelismConfig arguments + parallelism_config_args = parser.add_argument_group( + "ParallelismConfig Arguments", + "Arguments related to the ParallelismConfig used for distributed training.", + ) + + parallelism_config_args.add_argument( + "--parallelism_config_dp_replicate_size", + type=int, + default=1, + help="The number of processes for data parallel training. Defaults to 1 (no data parallelism).", + ) + + parallelism_config_args.add_argument( + "--parallelism_config_dp_shard_size", + type=int, + default=1, + help="The number of processes for FSDP sharding. Defaults to 1 (No FSDP sharding).", + ) + + parallelism_config_args.add_argument( + "--parallelism_config_tp_size", + type=int, + default=1, + help="The number of processes for tensor parallel training. Defaults to 1 (no tensor parallelism).", + ) + + parallelism_config_args.add_argument( + "--parallelism_config_cp_size", + type=int, + default=1, + help="The number of processese for context parallel training. Defaults to 1 (no context parallelism).", + ) + + parallelism_config_args.add_argument( + "--parallelism_config_cp_backend", + type=str, + choices=["torch"], + default="torch", + help="Context Parallelism backend: torch (FSDP2) or deepspeed (ALST/Ulysses)", + ) + + parallelism_config_args.add_argument( + "--parallelism_config_cp_comm_strategy", + type=str, + default="allgather", + help="The communication strategy for context parallel training. Defaults to 'allgather'. Other option is alltoall", + ) + + parallelism_config_args.add_argument( + "--parallelism_config_sp_size", + type=int, + default=1, + help="The number of processese for context parallel training. Defaults to 1 (no context parallelism).", + ) + + parallelism_config_args.add_argument( + "--parallelism_config_sp_backend", + type=str, + choices=["deepspeed"], + default="deepspeed", + help="Sequence Parallelism backend: deepspeed (ALST/Ulysses)", + ) + + parallelism_config_args.add_argument( + "--parallelism_config_sp_seq_length", + type=str, + default=None, + help="Sequence length for when batches are all of the same length. For variable sequence lengths across batches set `parallelism_config_sp_seq_length_is_variable=True`", + ) + + parallelism_config_args.add_argument( + "--parallelism_config_sp_seq_length_is_variable", + type=bool, + default=True, + help="If `True` will work with a sequence length that may change between batches, in which case `parallelism_config_sp_seq_length` value can be set to anything divisible by sp size or remain unset. If `False` then `parallelism_config_sp_seq_length` needs to match the batch's sequence length dimension. The default is `True`.", + ) + + parallelism_config_args.add_argument( + "--parallelism_config_sp_attn_implementation", + type=str, + default="sdpa", + help="Attention implementation to use. Can be one of 'flash_attention_2', 'flash_attention_3' or 'sdpa'. Defaults to `sdpa`.", + ) + + # Other arguments of the training scripts + parser.add_argument("training_script_args", nargs=argparse.REMAINDER, help="Arguments of the training script.") + + if subparsers is not None: + parser.set_defaults(func=launch_command) + return parser + + +def simple_launcher(args): + cmd, current_env = prepare_simple_launcher_cmd_env(args) + + process = subprocess.Popen(cmd, env=current_env) + process.wait() + if process.returncode != 0: + if not args.quiet: + raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd) + else: + sys.exit(1) + + +def multi_gpu_launcher(args): + import torch.distributed.run as distrib_run + + current_env = prepare_multi_gpu_env(args) + if not check_cuda_p2p_ib_support(): + message = "Using RTX 4000 series which doesn't support faster communication speedups. Ensuring P2P and IB communications are disabled." + warn = False + if "NCCL_P2P_DISABLE" not in current_env: + current_env["NCCL_P2P_DISABLE"] = "1" + warn = True + if "NCCL_IB_DISABLE" not in current_env: + current_env["NCCL_IB_DISABLE"] = "1" + warn = True + if warn: + logger.warning(message) + + debug = getattr(args, "debug", False) + args = _filter_args( + args, + distrib_run.get_args_parser(), + ["--training_script", args.training_script, "--training_script_args", args.training_script_args], + ) + + with patch_environment(**current_env): + try: + distrib_run.run(args) + except Exception: + if is_rich_available() and debug: + console = get_console() + console.print("\n[bold red]Using --debug, `torch.distributed` Stack Trace:[/bold red]") + console.print_exception(suppress=[__file__], show_locals=False) + else: + raise + + +def deepspeed_launcher(args): + import torch.distributed.run as distrib_run + + if not is_deepspeed_available(): + raise ImportError("DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source.") + else: + from deepspeed.launcher.runner import DEEPSPEED_ENVIRONMENT_NAME + + cmd, current_env = prepare_deepspeed_cmd_env(args) + if not check_cuda_p2p_ib_support(): + message = "Using RTX 4000 series which doesn't support faster communication speedups. Ensuring P2P and IB communications are disabled." + warn = False + if "NCCL_P2P_DISABLE" not in current_env: + current_env["NCCL_P2P_DISABLE"] = "1" + warn = True + if "NCCL_IB_DISABLE" not in current_env: + current_env["NCCL_IB_DISABLE"] = "1" + warn = True + if warn: + logger.warning(message) + + if args.num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]: + with open(DEEPSPEED_ENVIRONMENT_NAME, "a") as f: + valid_env_items = convert_dict_to_env_variables(current_env) + if len(valid_env_items) > 1: + f.writelines(valid_env_items) + + process = subprocess.Popen(cmd, env=current_env) + process.wait() + if process.returncode != 0: + if not args.quiet: + raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd) + else: + sys.exit(1) + else: + debug = getattr(args, "debug", False) + args = _filter_args( + args, + distrib_run.get_args_parser(), + ["--training_script", args.training_script, "--training_script_args", args.training_script_args], + ) + with patch_environment(**current_env): + try: + distrib_run.run(args) + except Exception: + if is_rich_available() and debug: + console = get_console() + console.print("\n[bold red]Using --debug, `torch.distributed` Stack Trace:[/bold red]") + console.print_exception(suppress=[__file__], show_locals=False) + else: + raise + + +def tpu_launcher(args): + import torch_xla.distributed.xla_multiprocessing as xmp + + if args.no_python: + raise ValueError("--no_python cannot be used with TPU launcher") + + args, current_env = prepare_tpu(args, {}) + + if args.module: + mod_name = args.training_script + else: + # Import training_script as a module + script_path = Path(args.training_script) + sys.path.append(str(script_path.parent.resolve())) + mod_name = script_path.stem + + mod = importlib.import_module(mod_name) + if not hasattr(mod, args.main_training_function): + raise ValueError( + f"Your training script should have a function named {args.main_training_function}, or you should pass a " + "different value to `--main_training_function`." + ) + + # Patch sys.argv + sys.argv = [mod.__file__] + args.training_script_args + + main_function = getattr(mod, args.main_training_function) + with patch_environment(**current_env): + xmp.spawn(PrepareForLaunch(main_function), args=()) + + +def tpu_pod_launcher(args): + from torch_xla.distributed import xla_dist + + current_env = {} + args, current_env = prepare_tpu(args, current_env, True) + debug = getattr(args, "debug", False) + + training_script = args.training_script + training_script_args = args.training_script_args + new_args = _filter_args( + args, xla_dist.get_args_parser(), ["--tpu", args.tpu_name, "--positional", "", "--restart-tpuvm-pod-server"] + ) + + if args.tpu_use_sudo: + new_cmd = ["sudo"] + else: + new_cmd = [] + + new_cmd += [ + "accelerate-launch", + "--tpu", + "--no_tpu_cluster", + "--num_machines", + "1", + "--mixed_precision", + "no", + "--dynamo_backend", + "no", + "--num_processes", + str(args.num_processes), + "--main_training_function", + str(args.main_training_function), + training_script, + ] + training_script_args + + new_args.positional = new_cmd + bad_flags = "" + for arg in vars(new_args): + if arg.startswith("docker_"): + value = getattr(new_args, arg) + if value != "" and value is not None: + bad_flags += f'{arg}="{value}"\n' + if bad_flags != "": + raise ValueError( + f"Docker containers are not supported for TPU pod launcher currently, please remove the following flags:\n{bad_flags}" + ) + new_args.env = [f"{k}={v}" for k, v in current_env.items()] + new_args.env.append("ACCELERATE_IN_TPU_POD=1") + try: + xla_dist.resolve_and_execute(new_args) + except Exception: + if is_rich_available() and debug: + console = get_console() + console.print("\n[bold red]Using --debug, `torch_xla.xla_dist` Stack Trace:[/bold red]") + console.print_exception(suppress=[__file__], show_locals=False) + else: + raise + + +def sagemaker_launcher(sagemaker_config: SageMakerConfig, args): + if not is_sagemaker_available(): + raise ImportError( + "Please install sagemaker to be able to launch training on Amazon SageMaker with `pip install accelerate[sagemaker]`" + ) + if args.module or args.no_python: + raise ValueError( + "SageMaker requires a python training script file and cannot be used with --module or --no_python" + ) + + from sagemaker.huggingface import HuggingFace + + args, sagemaker_inputs = prepare_sagemager_args_inputs(sagemaker_config, args) + + huggingface_estimator = HuggingFace(**args) + + huggingface_estimator.fit(inputs=sagemaker_inputs) + print(f"You can find your model data at: {huggingface_estimator.model_data}") + + +def _validate_launch_command(args): + # Sanity checks + if sum([args.multi_gpu, args.cpu, args.tpu, args.use_deepspeed, args.use_fsdp]) > 1: + raise ValueError( + "You can only use one of `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp` at a time." + ) + if args.multi_gpu and (args.num_processes is not None) and (args.num_processes < 2): + raise ValueError("You need to use at least 2 processes to use `--multi_gpu`.") + + if (not args.use_fsdp or args.fsdp_version == 1) and args.use_parallelism_config: + raise ValueError("You cannot use `--use_parallelism_config` without `--use_fsdp` and `--fsdp_version=2`. ") + + defaults = None + warned = [] + mp_from_config_flag = False + # Get the default from the config file. + if args.config_file is not None or os.path.isfile(default_config_file) and not args.cpu: + defaults = load_config_from_file(args.config_file) + if ( + not args.multi_gpu + and not args.tpu + and not args.tpu_use_cluster + and not args.use_deepspeed + and not args.use_fsdp + and not args.use_megatron_lm + ): + args.use_deepspeed = defaults.distributed_type == DistributedType.DEEPSPEED + args.multi_gpu = ( + True + if defaults.distributed_type + in ( + DistributedType.MULTI_GPU, + DistributedType.MULTI_NPU, + DistributedType.MULTI_MLU, + DistributedType.MULTI_SDAA, + DistributedType.MULTI_MUSA, + DistributedType.MULTI_XPU, + DistributedType.MULTI_HPU, + DistributedType.MULTI_NEURON, + ) + else False + ) + args.tpu = defaults.distributed_type == DistributedType.XLA + args.use_fsdp = defaults.distributed_type == DistributedType.FSDP + args.use_megatron_lm = defaults.distributed_type == DistributedType.MEGATRON_LM + args.tpu_use_cluster = defaults.tpu_use_cluster if args.tpu else False + args.use_parallelism_config = defaults.parallelism_config != {} + if args.gpu_ids is None: + if defaults.gpu_ids is not None: + args.gpu_ids = defaults.gpu_ids + else: + args.gpu_ids = "all" + + if args.multi_gpu and args.num_machines is None: + args.num_machines = defaults.num_machines + + if len(args.gpu_ids.split(",")) < 2 and (args.gpu_ids != "all") and args.multi_gpu and args.num_machines <= 1: + raise ValueError( + "Less than two GPU ids were configured and tried to run on on multiple GPUs. " + "Please ensure at least two are specified for `--gpu_ids`, or use `--gpu_ids='all'`." + ) + if defaults.compute_environment == ComputeEnvironment.LOCAL_MACHINE: + # Update args with the defaults + for name, attr in defaults.__dict__.items(): + if isinstance(attr, dict): + # Copy defaults.somedict.somearg to args.somearg and + # defaults.fsdp_config.x to args.fsdp_x + for key, value in attr.items(): + if name == "fsdp_config" and not key.startswith("fsdp"): + key = "fsdp_" + key + elif name == "fp8_config" and not key.startswith("fp8"): + key = "fp8_" + key + if hasattr(args, "nondefault") and key not in args.nondefault: + setattr(args, key, value) + elif ( + name not in ["compute_environment", "mixed_precision", "distributed_type"] + and getattr(args, name, None) is None + ): + # Those args are handled separately + setattr(args, name, attr) + if not args.debug: + args.debug = defaults.debug + + if not args.mixed_precision: + if defaults.mixed_precision is None: + args.mixed_precision = "no" + else: + args.mixed_precision = defaults.mixed_precision + mp_from_config_flag = True + else: + native_amp = is_bf16_available(True) + if ( + args.mixed_precision == "bf16" + and not native_amp + and not (args.tpu and is_torch_xla_available(check_is_tpu=True)) + ): + raise ValueError("bf16 mixed precision requires PyTorch >= 1.10 and a supported device.") + + # Silently set the default here + if args.dynamo_backend is None: + args.dynamo_backend = "no" + if args.num_processes == -1: + raise ValueError("You need to manually pass in `--num_processes` using this config yaml.") + else: + if args.num_processes is None: + if is_xpu_available(): + args.num_processes = torch.xpu.device_count() + elif is_mlu_available(): + args.num_processes = torch.mlu.device_count() + elif is_sdaa_available(): + args.num_processes = torch.sdaa.device_count() + elif is_musa_available(): + args.num_processes = torch.musa.device_count() + elif is_npu_available(): + args.num_processes = torch.npu.device_count() + elif is_hpu_available(): + args.num_processes = torch.hpu.device_count() + elif is_neuron_available(): + args.num_processes = torch.neuron.device_count() + else: + args.num_processes = torch.cuda.device_count() + warned.append(f"\t`--num_processes` was set to a value of `{args.num_processes}`") + if args.debug is None: + args.debug = False + if ( + not args.multi_gpu + and args.num_processes > 1 + and ( + (is_xpu_available() and torch.xpu.device_count() > 1) + or (is_npu_available() and torch.npu.device_count() > 1) + or (is_hpu_available() and torch.hpu.device_count() > 1) + or (is_mlu_available() and torch.mlu.device_count() > 1) + or (is_sdaa_available() and torch.sdaa.device_count() > 1) + or (is_musa_available() and torch.musa.device_count() > 1) + or (is_neuron_available() and torch.neuron.device_count() > 1) + or (torch.cuda.is_available() and torch.cuda.device_count() > 1) + ) + ): + warned.append( + "\t\tMore than one GPU was found, enabling multi-GPU training.\n" + "\t\tIf this was unintended please pass in `--num_processes=1`." + ) + args.multi_gpu = True + if args.num_machines is None: + warned.append("\t`--num_machines` was set to a value of `1`") + args.num_machines = 1 + if args.mixed_precision is None: + warned.append("\t`--mixed_precision` was set to a value of `'no'`") + args.mixed_precision = "no" + if not hasattr(args, "use_cpu"): + args.use_cpu = args.cpu + if args.dynamo_backend is None: + warned.append("\t`--dynamo_backend` was set to a value of `'no'`") + args.dynamo_backend = "no" + if args.debug: + logger.debug("Running script in debug mode, expect distributed operations to be slightly slower.") + + is_aws_env_disabled = defaults is None or ( + defaults is not None and defaults.compute_environment != ComputeEnvironment.AMAZON_SAGEMAKER + ) + if is_aws_env_disabled and args.num_cpu_threads_per_process is None: + args.num_cpu_threads_per_process = get_int_from_env(["OMP_NUM_THREADS"], 1) + if args.use_cpu and args.num_processes >= 1 and get_int_from_env(["OMP_NUM_THREADS"], 0) == 0: + local_size = get_int_from_env( + ["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], + max(int(args.num_processes / args.num_machines), 1), + ) + import psutil + + threads_per_process = int(psutil.cpu_count(logical=False) / local_size) + if threads_per_process > 1: + args.num_cpu_threads_per_process = threads_per_process + warned.append( + f"\t`--num_cpu_threads_per_process` was set to `{args.num_cpu_threads_per_process}` to improve out-of-box performance when training on CPUs" + ) + + if any(warned): + message = "The following values were not passed to `accelerate launch` and had defaults used instead:\n" + message += "\n".join(warned) + message += ( + "\nTo avoid this warning pass in values for each of the problematic parameters or run `accelerate config`." + ) + logger.warning(message) + return args, defaults, mp_from_config_flag + + +def launch_command(args): + args, defaults, mp_from_config_flag = _validate_launch_command(args) + # Use the proper launcher + if args.use_deepspeed and not args.cpu: + args.deepspeed_fields_from_accelerate_config = list(defaults.deepspeed_config.keys()) if defaults else [] + if mp_from_config_flag: + args.deepspeed_fields_from_accelerate_config.append("mixed_precision") + args.deepspeed_fields_from_accelerate_config = ",".join(args.deepspeed_fields_from_accelerate_config) + deepspeed_launcher(args) + elif args.use_fsdp and not args.cpu: + multi_gpu_launcher(args) + elif args.use_megatron_lm and not args.cpu: + multi_gpu_launcher(args) + elif args.multi_gpu and not args.cpu: + multi_gpu_launcher(args) + elif args.tpu and not args.cpu: + if args.tpu_use_cluster: + tpu_pod_launcher(args) + else: + tpu_launcher(args) + elif defaults is not None and defaults.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER: + sagemaker_launcher(defaults, args) + else: + simple_launcher(args) + + +def main(): + parser = launch_command_parser() + args = parser.parse_args() + launch_command(args) + + +if __name__ == "__main__": + main() diff --git a/accelerate/commands/menu/__init__.py b/accelerate/commands/menu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c2c851cc0b192ab8207d3fa68d7409868c84354c --- /dev/null +++ b/accelerate/commands/menu/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .selection_menu import BulletMenu diff --git a/accelerate/commands/menu/cursor.py b/accelerate/commands/menu/cursor.py new file mode 100644 index 0000000000000000000000000000000000000000..c1f0bb7b68025ae4fe0c2c76c095eb36b4e64f2c --- /dev/null +++ b/accelerate/commands/menu/cursor.py @@ -0,0 +1,65 @@ +# Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A utility for showing and hiding the terminal cursor on Windows and Linux, based on https://github.com/bchao1/bullet +""" + +import os +import sys +from contextlib import contextmanager + + +# Windows only +if os.name == "nt": + import ctypes + import msvcrt # noqa + + class CursorInfo(ctypes.Structure): + # _fields is a specific attr expected by ctypes + _fields_ = [("size", ctypes.c_int), ("visible", ctypes.c_byte)] + + +def hide_cursor(): + if os.name == "nt": + ci = CursorInfo() + handle = ctypes.windll.kernel32.GetStdHandle(-11) + ctypes.windll.kernel32.GetConsoleCursorInfo(handle, ctypes.byref(ci)) + ci.visible = False + ctypes.windll.kernel32.SetConsoleCursorInfo(handle, ctypes.byref(ci)) + elif os.name == "posix": + sys.stdout.write("\033[?25l") + sys.stdout.flush() + + +def show_cursor(): + if os.name == "nt": + ci = CursorInfo() + handle = ctypes.windll.kernel32.GetStdHandle(-11) + ctypes.windll.kernel32.GetConsoleCursorInfo(handle, ctypes.byref(ci)) + ci.visible = True + ctypes.windll.kernel32.SetConsoleCursorInfo(handle, ctypes.byref(ci)) + elif os.name == "posix": + sys.stdout.write("\033[?25h") + sys.stdout.flush() + + +@contextmanager +def hide(): + "Context manager to hide the terminal cursor" + try: + hide_cursor() + yield + finally: + show_cursor() diff --git a/accelerate/commands/menu/helpers.py b/accelerate/commands/menu/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..de46f37ddcf4591167e3e01791391e4b1729034f --- /dev/null +++ b/accelerate/commands/menu/helpers.py @@ -0,0 +1,59 @@ +# Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A variety of helper functions and constants when dealing with terminal menu choices, based on +https://github.com/bchao1/bullet +""" + +import enum +import shutil +import sys + + +TERMINAL_WIDTH, _ = shutil.get_terminal_size() + +CURSOR_TO_CHAR = {"UP": "A", "DOWN": "B", "RIGHT": "C", "LEFT": "D"} + + +class Direction(enum.Enum): + UP = 0 + DOWN = 1 + + +def forceWrite(content, end=""): + sys.stdout.write(str(content) + end) + sys.stdout.flush() + + +def writeColor(content, color, end=""): + forceWrite(f"\u001b[{color}m{content}\u001b[0m", end) + + +def reset_cursor(): + forceWrite("\r") + + +def move_cursor(num_lines: int, direction: str): + forceWrite(f"\033[{num_lines}{CURSOR_TO_CHAR[direction.upper()]}") + + +def clear_line(): + forceWrite(" " * TERMINAL_WIDTH) + reset_cursor() + + +def linebreak(): + reset_cursor() + forceWrite("-" * TERMINAL_WIDTH) diff --git a/accelerate/commands/menu/input.py b/accelerate/commands/menu/input.py new file mode 100644 index 0000000000000000000000000000000000000000..f1270eaece9d4243e7282dcb31166feeeb9bdfc1 --- /dev/null +++ b/accelerate/commands/menu/input.py @@ -0,0 +1,84 @@ +# Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file contains utilities for handling input from the user and registering specific keys to specific functions, +based on https://github.com/bchao1/bullet +""" + +from .keymap import KEYMAP, get_character + + +def mark(key: str): + """ + Mark the function with the key code so it can be handled in the register + """ + + def decorator(func): + handle = getattr(func, "handle_key", []) + handle += [key] + func.handle_key = handle + return func + + return decorator + + +def mark_multiple(*keys: list[str]): + """ + Mark the function with the key codes so it can be handled in the register + """ + + def decorator(func): + handle = getattr(func, "handle_key", []) + handle += keys + func.handle_key = handle + return func + + return decorator + + +class KeyHandler(type): + """ + Metaclass that adds the key handlers to the class + """ + + def __new__(cls, name, bases, attrs): + new_cls = super().__new__(cls, name, bases, attrs) + if not hasattr(new_cls, "key_handler"): + new_cls.key_handler = {} + new_cls.handle_input = KeyHandler.handle_input + + for value in attrs.values(): + handled_keys = getattr(value, "handle_key", []) + for key in handled_keys: + new_cls.key_handler[key] = value + return new_cls + + @staticmethod + def handle_input(cls): + "Finds and returns the selected character if it exists in the handler" + char = get_character() + if char != KEYMAP["undefined"]: + char = ord(char) + handler = cls.key_handler.get(char) + if handler: + cls.current_selection = char + return handler(cls) + else: + return None + + +def register(cls): + """Adds KeyHandler metaclass to the class""" + return KeyHandler(cls.__name__, cls.__bases__, cls.__dict__.copy()) diff --git a/accelerate/commands/menu/keymap.py b/accelerate/commands/menu/keymap.py new file mode 100644 index 0000000000000000000000000000000000000000..787db12860fe21c6786dda69c34fcccab114f2f8 --- /dev/null +++ b/accelerate/commands/menu/keymap.py @@ -0,0 +1,133 @@ +# Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utilities relating to parsing raw characters from the keyboard, based on https://github.com/bchao1/bullet +""" + +import os +import string +import sys + + +ARROW_KEY_FLAG = 1 << 8 + +KEYMAP = { + "tab": ord("\t"), + "newline": ord("\r"), + "esc": 27, + "up": 65 + ARROW_KEY_FLAG, + "down": 66 + ARROW_KEY_FLAG, + "right": 67 + ARROW_KEY_FLAG, + "left": 68 + ARROW_KEY_FLAG, + "mod_int": 91, + "undefined": sys.maxsize, + "interrupt": 3, + "insert": 50, + "delete": 51, + "pg_up": 53, + "pg_down": 54, +} + +KEYMAP["arrow_begin"] = KEYMAP["up"] +KEYMAP["arrow_end"] = KEYMAP["left"] + +if sys.platform == "win32": + WIN_CH_BUFFER = [] + WIN_KEYMAP = { + b"\xe0H": KEYMAP["up"] - ARROW_KEY_FLAG, + b"\x00H": KEYMAP["up"] - ARROW_KEY_FLAG, + b"\xe0P": KEYMAP["down"] - ARROW_KEY_FLAG, + b"\x00P": KEYMAP["down"] - ARROW_KEY_FLAG, + b"\xe0M": KEYMAP["right"] - ARROW_KEY_FLAG, + b"\x00M": KEYMAP["right"] - ARROW_KEY_FLAG, + b"\xe0K": KEYMAP["left"] - ARROW_KEY_FLAG, + b"\x00K": KEYMAP["left"] - ARROW_KEY_FLAG, + } + +for i in range(10): + KEYMAP[str(i)] = ord(str(i)) + + +def get_raw_chars(): + "Gets raw characters from inputs" + if os.name == "nt": + import msvcrt + + encoding = "mbcs" + # Flush the keyboard buffer + while msvcrt.kbhit(): + msvcrt.getch() + if len(WIN_CH_BUFFER) == 0: + # Read the keystroke + ch = msvcrt.getch() + + # If it is a prefix char, get second part + if ch in (b"\x00", b"\xe0"): + ch2 = ch + msvcrt.getch() + # Translate actual Win chars to bullet char types + try: + chx = chr(WIN_KEYMAP[ch2]) + WIN_CH_BUFFER.append(chr(KEYMAP["mod_int"])) + WIN_CH_BUFFER.append(chx) + if ord(chx) in ( + KEYMAP["insert"] - 1 << 9, + KEYMAP["delete"] - 1 << 9, + KEYMAP["pg_up"] - 1 << 9, + KEYMAP["pg_down"] - 1 << 9, + ): + WIN_CH_BUFFER.append(chr(126)) + ch = chr(KEYMAP["esc"]) + except KeyError: + ch = ch2[1] + else: + ch = ch.decode(encoding) + else: + ch = WIN_CH_BUFFER.pop(0) + elif os.name == "posix": + import termios + import tty + + fd = sys.stdin.fileno() + old_settings = termios.tcgetattr(fd) + try: + tty.setraw(fd) + ch = sys.stdin.read(1) + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + return ch + + +def get_character(): + "Gets a character from the keyboard and returns the key code" + char = get_raw_chars() + if ord(char) in [KEYMAP["interrupt"], KEYMAP["newline"]]: + return char + + elif ord(char) == KEYMAP["esc"]: + combo = get_raw_chars() + if ord(combo) == KEYMAP["mod_int"]: + key = get_raw_chars() + if ord(key) >= KEYMAP["arrow_begin"] - ARROW_KEY_FLAG and ord(key) <= KEYMAP["arrow_end"] - ARROW_KEY_FLAG: + return chr(ord(key) + ARROW_KEY_FLAG) + else: + return KEYMAP["undefined"] + else: + return get_raw_chars() + + else: + if char in string.printable: + return char + else: + return KEYMAP["undefined"] diff --git a/accelerate/commands/menu/selection_menu.py b/accelerate/commands/menu/selection_menu.py new file mode 100644 index 0000000000000000000000000000000000000000..ca66b2ed6782f819ffc69bd0bb3be5b7139a3863 --- /dev/null +++ b/accelerate/commands/menu/selection_menu.py @@ -0,0 +1,145 @@ +# Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Main driver for the selection menu, based on https://github.com/bchao1/bullet +""" + +import builtins +import sys +from typing import Optional + +from ...utils.imports import _is_package_available +from . import cursor, input +from .helpers import Direction, clear_line, forceWrite, linebreak, move_cursor, reset_cursor, writeColor +from .keymap import KEYMAP + + +in_colab = False +try: + in_colab = _is_package_available("google.colab") +except ModuleNotFoundError: + pass + + +@input.register +class BulletMenu: + """ + A CLI menu to select a choice from a list of choices using the keyboard. + """ + + def __init__(self, prompt: Optional[str] = None, choices: list = []): + self.position = 0 + self.choices = choices + self.prompt = prompt + if sys.platform == "win32": + self.arrow_char = "*" + else: + self.arrow_char = "➔ " + + def write_choice(self, index, end: str = ""): + if sys.platform != "win32": + writeColor(self.choices[index], 32, end) + else: + forceWrite(self.choices[index], end) + + def print_choice(self, index: int): + "Prints the choice at the given index" + if index == self.position: + forceWrite(f" {self.arrow_char} ") + self.write_choice(index) + else: + forceWrite(f" {self.choices[index]}") + reset_cursor() + + def move_direction(self, direction: Direction, num_spaces: int = 1): + "Should not be directly called, used to move a direction of either up or down" + old_position = self.position + if direction == Direction.DOWN: + if self.position + 1 >= len(self.choices): + return + self.position += num_spaces + else: + if self.position - 1 < 0: + return + self.position -= num_spaces + clear_line() + self.print_choice(old_position) + move_cursor(num_spaces, direction.name) + self.print_choice(self.position) + + @input.mark(KEYMAP["up"]) + def move_up(self): + self.move_direction(Direction.UP) + + @input.mark(KEYMAP["down"]) + def move_down(self): + self.move_direction(Direction.DOWN) + + @input.mark(KEYMAP["newline"]) + def select(self): + move_cursor(len(self.choices) - self.position, "DOWN") + return self.position + + @input.mark(KEYMAP["interrupt"]) + def interrupt(self): + move_cursor(len(self.choices) - self.position, "DOWN") + raise KeyboardInterrupt + + @input.mark_multiple(*[KEYMAP[str(number)] for number in range(10)]) + def select_row(self): + index = int(chr(self.current_selection)) + movement = index - self.position + if index == self.position: + return + if index < len(self.choices): + if self.position > index: + self.move_direction(Direction.UP, -movement) + elif self.position < index: + self.move_direction(Direction.DOWN, movement) + else: + return + else: + return + + def run(self, default_choice: int = 0): + "Start the menu and return the selected choice" + if self.prompt: + linebreak() + forceWrite(self.prompt, "\n") + if in_colab: + forceWrite("Please input a choice index (starting from 0), and press enter", "\n") + else: + forceWrite("Please select a choice using the arrow or number keys, and selecting with enter", "\n") + self.position = default_choice + for i in range(len(self.choices)): + self.print_choice(i) + forceWrite("\n") + move_cursor(len(self.choices) - self.position, "UP") + with cursor.hide(): + while True: + if in_colab: + try: + choice = int(builtins.input()) + except ValueError: + choice = default_choice + else: + choice = self.handle_input() + if choice is not None: + reset_cursor() + for _ in range(len(self.choices) + 1): + move_cursor(1, "UP") + clear_line() + self.write_choice(choice, "\n") + return choice diff --git a/accelerate/commands/merge.py b/accelerate/commands/merge.py new file mode 100644 index 0000000000000000000000000000000000000000..475b53b5bbb71b959057126f8667d7f61eb9d0e1 --- /dev/null +++ b/accelerate/commands/merge.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from accelerate.commands.utils import CustomArgumentParser +from accelerate.utils import merge_fsdp_weights + + +description = """Utility to merge the weights from multiple FSDP checkpoints into a single combined checkpoint. Should be used if +`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}`. + +This is a CPU-bound process and requires enough RAM to load the entire model state dict.""" + + +def merge_command(args): + merge_fsdp_weights( + args.checkpoint_directory, args.output_path, not args.unsafe_serialization, args.remove_checkpoint_dir + ) + + +def merge_command_parser(subparsers=None): + if subparsers is not None: + parser = subparsers.add_parser("merge-weights", description=description) + else: + parser = CustomArgumentParser(description=description) + + parser.add_argument("checkpoint_directory", type=str, help="A directory containing sharded weights saved by FSDP.") + parser.add_argument( + "output_path", + type=str, + help="The path to save the merged weights. Defaults to the current directory. ", + ) + parser.add_argument( + "--unsafe_serialization", + action="store_true", + default=False, + help="Whether to save the merged weights as `.bin` rather than `.safetensors` (not recommended).", + ) + parser.add_argument( + "--remove_checkpoint_dir", + action="store_true", + help="Whether to remove the checkpoint directory after merging.", + default=False, + ) + + if subparsers is not None: + parser.set_defaults(func=merge_command) + return parser + + +def main(): + parser = merge_command_parser() + args = parser.parse_args() + merge_command(args) + + +if __name__ == "__main__": + main() diff --git a/accelerate/commands/test.py b/accelerate/commands/test.py new file mode 100644 index 0000000000000000000000000000000000000000..a0d2f7bcf14727aa13e3438f4cd6e6f140f5bb2f --- /dev/null +++ b/accelerate/commands/test.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +from accelerate.test_utils import execute_subprocess_async, path_in_accelerate_package + + +def test_command_parser(subparsers=None): + if subparsers is not None: + parser = subparsers.add_parser("test") + else: + parser = argparse.ArgumentParser("Accelerate test command") + + parser.add_argument( + "--config_file", + default=None, + help=( + "The path to use to store the config file. Will default to a file named default_config.yaml in the cache " + "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have " + "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed " + "with 'huggingface'." + ), + ) + + if subparsers is not None: + parser.set_defaults(func=test_command) + return parser + + +def test_command(args): + script_name = path_in_accelerate_package("test_utils", "scripts", "test_script.py") + + if args.config_file is None: + test_args = [script_name] + else: + test_args = f"--config_file={args.config_file} {script_name}".split() + + cmd = ["accelerate-launch"] + test_args + result = execute_subprocess_async(cmd) + if result.returncode == 0: + print("Test is a success! You are ready for your distributed training!") + + +def main(): + parser = test_command_parser() + args = parser.parse_args() + test_command(args) + + +if __name__ == "__main__": + main() diff --git a/accelerate/commands/to_fsdp2.py b/accelerate/commands/to_fsdp2.py new file mode 100644 index 0000000000000000000000000000000000000000..d2ce22f999caf60ad7bfd142289e18a56fe76280 --- /dev/null +++ b/accelerate/commands/to_fsdp2.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +import logging +from pathlib import Path + +import yaml + +from accelerate.commands.utils import CustomArgumentParser + + +class ConversionStatus(enum.Enum): + NOT_YET_IMPLEMENTED = 0 + REMOVED = -1 + + +ARGUMENT_KEY_MAPPING = { + # New keys in FSDP2 + "fsdp_version": "fsdp_version", + "fsdp_reshard_after_forward": "fsdp_reshard_after_forward", + # https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md + # https://huggingface.co/docs/accelerate/en/usage_guides/fsdp + "fsdp_auto_wrap_policy": "fsdp_auto_wrap_policy", + "fsdp_backward_prefetch": ConversionStatus.REMOVED, + "fsdp_forward_prefetch": ConversionStatus.NOT_YET_IMPLEMENTED, + "fsdp_cpu_ram_efficient_loading": "fsdp_cpu_ram_efficient_loading", + "fsdp_offload_params": "fsdp_offload_params", + "fsdp_sharding_strategy": "fsdp_reshard_after_forward", + "fsdp_state_dict_type": "fsdp_state_dict_type", + "fsdp_sync_module_states": ConversionStatus.REMOVED, + "fsdp_transformer_layer_cls_to_wrap": "fsdp_transformer_layer_cls_to_wrap", + "fsdp_min_num_params": "fsdp_min_num_params", + "fsdp_use_orig_params": ConversionStatus.REMOVED, + "fsdp_activation_checkpointing": "fsdp_activation_checkpointing", +} + +ARGUMENT_VALUE_MAPPING = { + "fsdp_sharding_strategy": { + "FULL_SHARD": True, + "SHARD_GRAD_OP": False, + "HYBRID_SHARD": True, + "HYBRID_SHARD_ZERO2": False, + "NO_SHARD": False, + }, + "fsdp_reshard_after_forward": { # Needed to convert newly created configs using FSDP1 to FSDP2 + "FULL_SHARD": True, + "SHARD_GRAD_OP": False, + "HYBRID_SHARD": True, + "HYBRID_SHARD_ZERO2": False, + "NO_SHARD": False, + }, +} + +logger = logging.getLogger(__name__) + + +def _validate_to_fsdp2_args(args): + if not Path(args.config_file).exists(): + raise FileNotFoundError(f"Config file {args.config_file} not found") + + if not args.overwrite and args.output_file is None: + raise ValueError("If --overwrite is not set, --output_file must be provided") + + if not args.overwrite and Path(args.output_file).exists(): + raise FileExistsError(f"Output file {args.output_file} already exists and --overwrite is not set") + + +def convert_config_to_fsdp2(config: dict) -> dict: + fsdp_config = config.get("fsdp_config", {}) + + if not fsdp_config: + logger.info("No FSDP config found in the config file, skipping conversion...") + return config + + new_fsdp_config = {} + + if fsdp_config.get("fsdp_version", 1) == 2: + logger.warning("Config already specifies FSDP2, skipping conversion...") + logger.warning( + "If the config doesn't use new argument names, change `fsdp_version` to `1` and rerun the command." + ) + return config + + for key, value in fsdp_config.items(): + conversion_status = ARGUMENT_KEY_MAPPING.get(key, None) + if isinstance(conversion_status, ConversionStatus) or conversion_status is None: + conversion_status = key + new_fsdp_config[conversion_status] = value + continue + + if conversion_status == ConversionStatus.REMOVED: + logger.warning(f"Argument {key} has been removed in FSDP2, skipping this key...") + continue + + if conversion_status == ConversionStatus.NOT_YET_IMPLEMENTED: + logger.warning(f"Argument {key} is not yet implemented in FSDP2, skipping this key...") + continue + + if conversion_status is None: + logger.warning(f"Argument {key} is not being converted, skipping this key...") + new_fsdp_config[key] = value + else: + if key in ARGUMENT_VALUE_MAPPING: + value = ARGUMENT_VALUE_MAPPING[key].get(value, value) + new_fsdp_config[ARGUMENT_KEY_MAPPING[key]] = value + + new_fsdp_config["fsdp_version"] = 2 + config["fsdp_config"] = new_fsdp_config + return config + + +def to_fsdp2_command_parser(subparsers=None): + description = "Convert an Accelerate config from FSDP1 to FSDP2" + + if subparsers is not None: + parser = subparsers.add_parser("to-fsdp2", description=description) + else: + parser = CustomArgumentParser(description=description) + + parser.add_argument("--config_file", type=str, help="The config file to convert to FSDP2", required=True) + parser.add_argument( + "--overwrite", + action="store_true", + help="Overwrite the config file if it exists", + default=False, + ) + parser.add_argument( + "--output_file", + type=str, + help="The path to the output file to write the converted config to. If not provided, the input file will be overwritten (if --overwrite is set)", + default=None, + ) + if subparsers is not None: + parser.set_defaults(func=to_fsdp2_command) + + return parser + + +def load_config(config_file: str) -> dict: + with open(config_file) as f: + config = yaml.safe_load(f) + if not config: + raise ValueError("Config file is empty") + + return config + + +def to_fsdp2_command(args): + _validate_to_fsdp2_args(args) + config = load_config(args.config_file) + + if args.overwrite and args.output_file is None: + args.output_file = args.config_file + + new_config = convert_config_to_fsdp2(config) + + with open(args.output_file, "w") as f: + yaml.dump(new_config, f) diff --git a/accelerate/commands/tpu.py b/accelerate/commands/tpu.py new file mode 100644 index 0000000000000000000000000000000000000000..fc0f07bf8697bfdb6484d3bf817f2e18b1313b00 --- /dev/null +++ b/accelerate/commands/tpu.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import subprocess + +from packaging.version import Version, parse + +from accelerate.commands.config.config_args import default_config_file, load_config_from_file + + +_description = "Run commands across TPU VMs for initial setup before running `accelerate launch`." + + +def tpu_command_parser(subparsers=None): + if subparsers is not None: + parser = subparsers.add_parser("tpu-config", description=_description) + else: + parser = argparse.ArgumentParser("Accelerate tpu-config command", description=_description) + # Core arguments + config_args = parser.add_argument_group( + "Config Arguments", "Arguments that can be configured through `accelerate config`." + ) + config_args.add_argument( + "--config_file", + type=str, + default=None, + help="Path to the config file to use for accelerate.", + ) + config_args.add_argument( + "--tpu_name", + default=None, + help="The name of the TPU to use. If not specified, will use the TPU specified in the config file.", + ) + config_args.add_argument( + "--tpu_zone", + default=None, + help="The zone of the TPU to use. If not specified, will use the zone specified in the config file.", + ) + pod_args = parser.add_argument_group("TPU Arguments", "Arguments for options ran inside the TPU.") + pod_args.add_argument( + "--use_alpha", + action="store_true", + help="Whether to use `gcloud alpha` when running the TPU training script instead of `gcloud`.", + ) + pod_args.add_argument( + "--command_file", + default=None, + help="The path to the file containing the commands to run on the pod on startup.", + ) + pod_args.add_argument( + "--command", + action="append", + nargs="+", + help="A command to run on the pod. Can be passed multiple times.", + ) + pod_args.add_argument( + "--install_accelerate", + action="store_true", + help="Whether to install accelerate on the pod. Defaults to False.", + ) + pod_args.add_argument( + "--accelerate_version", + default="latest", + help="The version of accelerate to install on the pod. If not specified, will use the latest pypi version. Specify 'dev' to install from GitHub.", + ) + pod_args.add_argument( + "--debug", action="store_true", help="If set, will print the command that would be run instead of running it." + ) + + if subparsers is not None: + parser.set_defaults(func=tpu_command_launcher) + return parser + + +def tpu_command_launcher(args): + defaults = None + + # Get the default from the config file if it exists. + if args.config_file is not None or os.path.isfile(default_config_file): + defaults = load_config_from_file(args.config_file) + if not args.command_file and defaults.command_file is not None and not args.command: + args.command_file = defaults.command_file + if not args.command and defaults.commands is not None: + args.command = defaults.commands + if not args.tpu_name: + args.tpu_name = defaults.tpu_name + if not args.tpu_zone: + args.tpu_zone = defaults.tpu_zone + if args.accelerate_version == "dev": + args.accelerate_version = "git+https://github.com/huggingface/accelerate.git" + elif args.accelerate_version == "latest": + args.accelerate_version = "accelerate -U" + elif isinstance(parse(args.accelerate_version), Version): + args.accelerate_version = f"accelerate=={args.accelerate_version}" + + if not args.command_file and not args.command: + raise ValueError("You must specify either a command file or a command to run on the pod.") + + if args.command_file: + with open(args.command_file) as f: + args.command = [f.read().splitlines()] + + # To turn list of lists into list of strings + if isinstance(args.command[0], list): + args.command = [line for cmd in args.command for line in cmd] + # Default to the shared folder and install accelerate + new_cmd = ["cd /usr/share"] + if args.install_accelerate: + new_cmd += [f"pip install {args.accelerate_version}"] + new_cmd += args.command + args.command = "; ".join(new_cmd) + + # Then send it to gcloud + # Eventually try to use google-api-core to do this instead of subprocess + cmd = ["gcloud"] + if args.use_alpha: + cmd += ["alpha"] + cmd += [ + "compute", + "tpus", + "tpu-vm", + "ssh", + args.tpu_name, + "--zone", + args.tpu_zone, + "--command", + args.command, + "--worker", + "all", + ] + if args.debug: + print(f"Running {' '.join(cmd)}") + return + subprocess.run(cmd) + print("Successfully setup pod.") + + +def main(): + parser = tpu_command_parser() + args = parser.parse_args() + + tpu_command_launcher(args) diff --git a/accelerate/commands/utils.py b/accelerate/commands/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..326f37d7f93de2417e4171e5ffe91193fb97225c --- /dev/null +++ b/accelerate/commands/utils.py @@ -0,0 +1,123 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + + +class _StoreAction(argparse.Action): + """ + Custom action that allows for `-` or `_` to be passed in for an argument. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + new_option_strings = [] + for option_string in self.option_strings: + new_option_strings.append(option_string) + if "_" in option_string[2:]: + # Add `-` version to the option string + new_option_strings.append(option_string.replace("_", "-")) + self.option_strings = new_option_strings + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, values) + if not hasattr(namespace, "nondefault"): + namespace.nondefault = set() + namespace.nondefault.add(self.dest) + + +class _StoreConstAction(_StoreAction): + """ + Same as `argparse._StoreConstAction` but uses the custom `_StoreAction`. + """ + + def __init__(self, option_strings, dest, const, default=None, required=False, help=None): + super().__init__( + option_strings=option_strings, + dest=dest, + nargs=0, + const=const, + default=default, + required=required, + help=help, + ) + + def __call__(self, parser, namespace, values, option_string=None): + super().__call__(parser, namespace, self.const, option_string) + + +class _StoreTrueAction(_StoreConstAction): + """ + Same as `argparse._StoreTrueAction` but uses the custom `_StoreConstAction`. + """ + + def __init__( + self, + option_strings, + dest, + default=None, + required=False, + help=None, + ): + super().__init__( + option_strings=option_strings, dest=dest, const=True, default=default, required=required, help=help + ) + + +class CustomArgumentGroup(argparse._ArgumentGroup): + """ + Custom argument group that allows for the use of `-` or `_` in arguments passed and overrides the help for each + when applicable. + """ + + def _add_action(self, action): + args = vars(action) + if isinstance(action, argparse._StoreTrueAction): + action = _StoreTrueAction( + args["option_strings"], args["dest"], args["default"], args["required"], args["help"] + ) + elif isinstance(action, argparse._StoreConstAction): + action = _StoreConstAction( + args["option_strings"], + args["dest"], + args["const"], + args["default"], + args["required"], + args["help"], + ) + elif isinstance(action, argparse._StoreAction): + action = _StoreAction(**args) + action = super()._add_action(action) + return action + + +class CustomArgumentParser(argparse.ArgumentParser): + """ + Custom argument parser that allows for the use of `-` or `_` in arguments passed and overrides the help for each + when applicable. + """ + + def add_argument(self, *args, **kwargs): + if "action" in kwargs: + # Translate action -> class + if kwargs["action"] == "store_true": + kwargs["action"] = _StoreTrueAction + else: + kwargs["action"] = _StoreAction + super().add_argument(*args, **kwargs) + + def add_argument_group(self, *args, **kwargs): + group = CustomArgumentGroup(self, *args, **kwargs) + self._action_groups.append(group) + return group diff --git a/accelerate/test_utils/__init__.py b/accelerate/test_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..67031886219f56344b1173c792ac4d8d421746d3 --- /dev/null +++ b/accelerate/test_utils/__init__.py @@ -0,0 +1,66 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .testing import ( + DEFAULT_LAUNCH_COMMAND, + are_the_same_tensors, + assert_exception, + capture_call_output, + device_count, + execute_subprocess_async, + get_launch_command, + get_torch_dist_unique_port, + memory_allocated_func, + path_in_accelerate_package, + pytest_xdist_worker_id, + require_bnb, + require_cpu, + require_cuda, + require_cuda_or_hpu, + require_cuda_or_xpu, + require_fp8, + require_fp16, + require_huggingface_suite, + require_mlu, + require_mps, + require_multi_device, + require_multi_gpu, + require_multi_gpu_or_xpu, + require_multi_xpu, + require_musa, + require_non_cpu, + require_non_hpu, + require_non_torch_xla, + require_non_xpu, + require_npu, + require_pippy, + require_sdaa, + require_single_device, + require_single_gpu, + require_single_xpu, + require_torch_min_version, + require_torchao, + require_torchvision, + require_tpu, + require_transformer_engine, + require_transformer_engine_mxfp8, + require_xpu, + run_first, + skip, + slow, + torch_device, +) +from .training import RegressionDataset, RegressionModel + + +from .scripts import test_script, test_sync, test_ops # isort: skip diff --git a/accelerate/test_utils/examples.py b/accelerate/test_utils/examples.py new file mode 100644 index 0000000000000000000000000000000000000000..18549cdb31492650e707b6008e25eeac8e33b04e --- /dev/null +++ b/accelerate/test_utils/examples.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A collection of utilities for comparing `examples/complete_*_example.py` scripts with the capabilities inside of each +`examples/by_feature` example. `compare_against_test` is the main function that should be used when testing, while the +others are used to either get the code that matters, or to preprocess them (such as stripping comments) +""" + +import os +from typing import Optional + + +def get_function_contents_by_name(lines: list[str], name: str): + """ + Extracts a function from `lines` of segmented source code with the name `name`. + + Args: + lines (`List[str]`): + Source code of a script separated by line. + name (`str`): + The name of the function to extract. Should be either `training_function` or `main` + """ + if name != "training_function" and name != "main": + raise ValueError(f"Incorrect function name passed: {name}, choose either 'main' or 'training_function'") + good_lines, found_start = [], False + for line in lines: + if not found_start and f"def {name}" in line: + found_start = True + good_lines.append(line) + continue + if found_start: + if name == "training_function" and "def main" in line: + return good_lines + if name == "main" and "if __name__" in line: + return good_lines + good_lines.append(line) + + +def clean_lines(lines: list[str]): + """ + Filters `lines` and removes any entries that start with a comment ('#') or is just a newline ('\n') + + Args: + lines (`List[str]`): + Source code of a script separated by line. + """ + return [line for line in lines if not line.lstrip().startswith("#") and line != "\n"] + + +def compare_against_test( + base_filename: str, feature_filename: str, parser_only: bool, secondary_filename: Optional[str] = None +): + """ + Tests whether the additional code inside of `feature_filename` was implemented in `base_filename`. This should be + used when testing to see if `complete_*_.py` examples have all of the implementations from each of the + `examples/by_feature/*` scripts. + + It utilizes `nlp_example.py` to extract out all of the repeated training code, so that only the new additional code + is examined and checked. If something *other* than `nlp_example.py` should be used, such as `cv_example.py` for the + `complete_cv_example.py` script, it should be passed in for the `secondary_filename` parameter. + + Args: + base_filename (`str` or `os.PathLike`): + The filepath of a single "complete" example script to test, such as `examples/complete_cv_example.py` + feature_filename (`str` or `os.PathLike`): + The filepath of a single feature example script. The contents of this script are checked to see if they + exist in `base_filename` + parser_only (`bool`): + Whether to compare only the `main()` sections in both files, or to compare the contents of + `training_loop()` + secondary_filename (`str`, *optional*): + A potential secondary filepath that should be included in the check. This function extracts the base + functionalities off of "examples/nlp_example.py", so if `base_filename` is a script other than + `complete_nlp_example.py`, the template script should be included here. Such as `examples/cv_example.py` + """ + with open(base_filename) as f: + base_file_contents = f.readlines() + with open(os.path.abspath(os.path.join("examples", "nlp_example.py"))) as f: + full_file_contents = f.readlines() + with open(feature_filename) as f: + feature_file_contents = f.readlines() + if secondary_filename is not None: + with open(secondary_filename) as f: + secondary_file_contents = f.readlines() + + # This is our base, we remove all the code from here in our `full_filename` and `feature_filename` to find the new content + if parser_only: + base_file_func = clean_lines(get_function_contents_by_name(base_file_contents, "main")) + full_file_func = clean_lines(get_function_contents_by_name(full_file_contents, "main")) + feature_file_func = clean_lines(get_function_contents_by_name(feature_file_contents, "main")) + if secondary_filename is not None: + secondary_file_func = clean_lines(get_function_contents_by_name(secondary_file_contents, "main")) + else: + base_file_func = clean_lines(get_function_contents_by_name(base_file_contents, "training_function")) + full_file_func = clean_lines(get_function_contents_by_name(full_file_contents, "training_function")) + feature_file_func = clean_lines(get_function_contents_by_name(feature_file_contents, "training_function")) + if secondary_filename is not None: + secondary_file_func = clean_lines( + get_function_contents_by_name(secondary_file_contents, "training_function") + ) + + _dl_line = "train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)\n" + + # Specific code in our script that differs from the full version, aka what is new + new_feature_code = [] + passed_idxs = [] # We keep track of the idxs just in case it's a repeated statement + it = iter(feature_file_func) + for i in range(len(feature_file_func) - 1): + if i not in passed_idxs: + line = next(it) + if (line not in full_file_func) and (line.lstrip() != _dl_line): + if "TESTING_MOCKED_DATALOADERS" not in line: + new_feature_code.append(line) + passed_idxs.append(i) + else: + # Skip over the `config['num_epochs'] = 2` statement + _ = next(it) + + # Extract out just the new parts from the full_file_training_func + new_full_example_parts = [] + passed_idxs = [] # We keep track of the idxs just in case it's a repeated statement + for i, line in enumerate(base_file_func): + if i not in passed_idxs: + if (line not in full_file_func) and (line.lstrip() != _dl_line): + if "TESTING_MOCKED_DATALOADERS" not in line: + new_full_example_parts.append(line) + passed_idxs.append(i) + + # Finally, get the overall diff + diff_from_example = [line for line in new_feature_code if line not in new_full_example_parts] + if secondary_filename is not None: + diff_from_two = [line for line in full_file_contents if line not in secondary_file_func] + diff_from_example = [line for line in diff_from_example if line not in diff_from_two] + + return diff_from_example diff --git a/accelerate/test_utils/scripts/__init__.py b/accelerate/test_utils/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9cbe26c257b515f657c05e1996d517e69613972 --- /dev/null +++ b/accelerate/test_utils/scripts/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/accelerate/test_utils/scripts/external_deps/__init__.py b/accelerate/test_utils/scripts/external_deps/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9cbe26c257b515f657c05e1996d517e69613972 --- /dev/null +++ b/accelerate/test_utils/scripts/external_deps/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/accelerate/test_utils/scripts/external_deps/test_checkpointing.py b/accelerate/test_utils/scripts/external_deps/test_checkpointing.py new file mode 100644 index 0000000000000000000000000000000000000000..6a1553898ec3d55e64822c204ddf7e705069ce8a --- /dev/null +++ b/accelerate/test_utils/scripts/external_deps/test_checkpointing.py @@ -0,0 +1,269 @@ +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json +import os + +import evaluate +import torch +from datasets import load_dataset +from torch.optim import AdamW +from torch.utils.data import DataLoader +from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed + +from accelerate import Accelerator, DistributedType +from accelerate.utils.deepspeed import DummyOptim, DummyScheduler + + +MAX_GPU_BATCH_SIZE = 16 +EVAL_BATCH_SIZE = 32 + + +def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, model_name: str = "bert-base-cased"): + """ + Creates a set of `DataLoader`s for the `glue` dataset. + + Args: + accelerator (`Accelerator`): + An `Accelerator` object + batch_size (`int`, *optional*): + The batch size for the train and validation DataLoaders. + model_name (`str`, *optional*): + """ + tokenizer = AutoTokenizer.from_pretrained(model_name) + datasets = load_dataset("glue", "mrpc") + + def tokenize_function(examples): + # max_length=None => use the model max length (it's actually the default) + outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None) + return outputs + + # Apply the method we just defined to all the examples in all the splits of the dataset + tokenized_datasets = datasets.map( + tokenize_function, batched=True, remove_columns=["idx", "sentence1", "sentence2"], load_from_cache_file=False + ) + + # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the + # transformers library + tokenized_datasets = tokenized_datasets.rename_column("label", "labels") + + def collate_fn(examples): + # On TPU it's best to pad everything to the same length or training will be very slow. + if accelerator.distributed_type == DistributedType.XLA: + return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt") + return tokenizer.pad(examples, padding="longest", return_tensors="pt") + + # Instantiate dataloaders. + train_dataloader = DataLoader( + tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size + ) + eval_dataloader = DataLoader( + tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE + ) + + return train_dataloader, eval_dataloader + + +def evaluation_loop(accelerator, model, eval_dataloader, metric): + model.eval() + samples_seen = 0 + for step, batch in enumerate(eval_dataloader): + # We could avoid this line since we set the accelerator with `device_placement=True`. + batch.to(accelerator.device) + with torch.no_grad(): + outputs = model(**batch) + predictions = outputs.logits.argmax(dim=-1) + # It is slightly faster to call this once, than multiple times + predictions, references = accelerator.gather( + (predictions, batch["labels"]) + ) # If we are in a multiprocess environment, the last batch has duplicates + if accelerator.use_distributed: + if step == len(eval_dataloader) - 1: + predictions = predictions[: len(eval_dataloader.dataset) - samples_seen] + references = references[: len(eval_dataloader.dataset) - samples_seen] + else: + samples_seen += references.shape[0] + metric.add_batch( + predictions=predictions, + references=references, + ) + + eval_metric = metric.compute() + return eval_metric["accuracy"] + + +def training_function(config, args): + # Initialize accelerator + accelerator = Accelerator() + + # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs + lr = config["lr"] + num_epochs = int(config["num_epochs"]) + seed = int(config["seed"]) + batch_size = int(config["batch_size"]) + model_name = args.model_name_or_path + + set_seed(seed) + train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size, model_name) + + # Instantiate the model (we build the model here so that the seed also control new weights initialization) + model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True) + + # Instantiate optimizer + optimizer_cls = ( + AdamW + if accelerator.state.deepspeed_plugin is None + or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config + else DummyOptim + ) + optimizer = optimizer_cls(params=model.parameters(), lr=lr) + + if accelerator.state.deepspeed_plugin is not None: + gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[ + "gradient_accumulation_steps" + ] + else: + gradient_accumulation_steps = 1 + max_training_steps = (len(train_dataloader) * num_epochs) // gradient_accumulation_steps + + # Instantiate scheduler + if ( + accelerator.state.deepspeed_plugin is None + or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config + ): + lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=0, + num_training_steps=max_training_steps, + ) + else: + lr_scheduler = DummyScheduler(optimizer, total_num_steps=max_training_steps, warmup_num_steps=0) + + # Prepare everything + # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the + # prepare method. + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + ) + + # We need to keep track of how many total steps we have iterated over + overall_step = 0 + # We also need to keep track of the stating epoch so files are named properly + starting_epoch = 0 + metric = evaluate.load("glue", "mrpc") + ending_epoch = num_epochs + + if args.partial_train_epoch is not None: + ending_epoch = args.partial_train_epoch + + if args.resume_from_checkpoint: + accelerator.load_state(args.resume_from_checkpoint) + epoch_string = args.resume_from_checkpoint.split("epoch_")[1] + state_epoch_num = "" + for char in epoch_string: + if char.isdigit(): + state_epoch_num += char + else: + break + starting_epoch = int(state_epoch_num) + 1 + accuracy = evaluation_loop(accelerator, model, eval_dataloader, metric) + accelerator.print("resumed checkpoint performance:", accuracy) + accelerator.print("resumed checkpoint's scheduler's lr:", lr_scheduler.get_lr()[0]) + accelerator.print("resumed optimizers's lr:", optimizer.param_groups[0]["lr"]) + with open(os.path.join(args.output_dir, f"state_{starting_epoch - 1}.json")) as f: + resumed_state = json.load(f) + assert resumed_state["accuracy"] == accuracy, "Accuracy mismatch, loading from checkpoint failed" + assert resumed_state["lr"] == lr_scheduler.get_lr()[0], ( + "Scheduler learning rate mismatch, loading from checkpoint failed" + ) + assert resumed_state["optimizer_lr"] == optimizer.param_groups[0]["lr"], ( + "Optimizer learning rate mismatch, loading from checkpoint failed" + ) + assert resumed_state["epoch"] == starting_epoch - 1, "Epoch mismatch, loading from checkpoint failed" + return + + # Now we train the model + state = {} + for epoch in range(starting_epoch, ending_epoch): + model.train() + for step, batch in enumerate(train_dataloader): + outputs = model(**batch) + loss = outputs.loss + loss = loss / gradient_accumulation_steps + accelerator.backward(loss) + if step % gradient_accumulation_steps == 0: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + overall_step += 1 + output_dir = f"epoch_{epoch}" + output_dir = os.path.join(args.output_dir, output_dir) + accelerator.save_state(output_dir) + accuracy = evaluation_loop(accelerator, model, eval_dataloader, metric) + state["accuracy"] = accuracy + state["lr"] = lr_scheduler.get_lr()[0] + state["optimizer_lr"] = optimizer.param_groups[0]["lr"] + state["epoch"] = epoch + state["step"] = overall_step + accelerator.print(f"epoch {epoch}:", state) + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + with open(os.path.join(args.output_dir, f"state_{epoch}.json"), "w") as f: + json.dump(state, f) + accelerator.end_training() + + +def main(): + parser = argparse.ArgumentParser(description="Simple example of training script tracking peak GPU memory usage.") + parser.add_argument( + "--model_name_or_path", + type=str, + default="bert-base-cased", + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=False, + ) + parser.add_argument( + "--output_dir", + type=str, + default=".", + help="Optional save directory where all checkpoint folders will be stored. Default is the current working directory.", + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help="If the training should continue from a checkpoint folder.", + ) + parser.add_argument( + "--partial_train_epoch", + type=int, + default=None, + help="If passed, the training will stop after this number of epochs.", + ) + parser.add_argument( + "--num_epochs", + type=int, + default=2, + help="Number of train epochs.", + ) + args = parser.parse_args() + config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 16} + + training_function(config, args) + + +if __name__ == "__main__": + main() diff --git a/accelerate/test_utils/scripts/external_deps/test_ds_alst_ulysses_sp.py b/accelerate/test_utils/scripts/external_deps/test_ds_alst_ulysses_sp.py new file mode 100644 index 0000000000000000000000000000000000000000..c756d5dc77f9ce51e305001697fad5c342575887 --- /dev/null +++ b/accelerate/test_utils/scripts/external_deps/test_ds_alst_ulysses_sp.py @@ -0,0 +1,131 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Test script for verifying ALST/Ulysses SP works +""" + +import torch +from deepspeed.runtime.utils import move_to_device +from transformers import AutoModelForCausalLM, AutoTokenizer + +from accelerate import Accelerator +from accelerate.utils import ParallelismConfig, set_seed +from accelerate.utils.dataclasses import DeepSpeedSequenceParallelConfig + + +set_seed(42) + +world_size = 2 +model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM" + +micro_batch_size = 1 + +parallelism_config = ParallelismConfig( + sp_backend="deepspeed", + sp_size=world_size, + # dp_shard_size=1, # set if dp is wanted as well + sp_handler=DeepSpeedSequenceParallelConfig( + sp_seq_length=256, + sp_seq_length_is_variable=True, + sp_attn_implementation="sdpa", + ), +) + +accelerator = Accelerator( + parallelism_config=parallelism_config, +) + +tokenizer = AutoTokenizer.from_pretrained(model_name) +model = AutoModelForCausalLM.from_pretrained(model_name) + +samples = 4 +seqlen = 32 +input_ids = torch.arange(1, seqlen * samples + 1).view(-1, seqlen) + 100 +position_ids = torch.arange(seqlen * samples).view(-1, seqlen) + +ds = torch.utils.data.TensorDataset(input_ids, position_ids) + + +def collate_fn(batch): + input_ids, position_ids = batch[0] + return dict( + input_ids=input_ids.unsqueeze(0), + position_ids=position_ids.unsqueeze(0), + labels=input_ids.unsqueeze(0), + ) + + +dl = torch.utils.data.DataLoader(ds, batch_size=micro_batch_size, collate_fn=collate_fn) + +optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) + +rank = torch.distributed.get_rank() + +if rank == 0: + print(f"DL orig: {len(dl)} samples") + +model, optimizer, dl = accelerator.prepare(model, optimizer, dl) + +if rank == 0: + print(f"DL w/ adapter: {len(dl)} samples") + +sp_size = parallelism_config.sp_size if parallelism_config else 1 +if sp_size > 1: + from deepspeed.utils import groups + + sp_group = groups._get_sequence_parallel_group() + sp_world_size = parallelism_config.sp_size + +unwrapped_model = accelerator.unwrap_model(model) + +# Normal training loop +for iter, batch in enumerate(dl): + optimizer.zero_grad() + + if rank == 0: + print(f"batch {iter}: seqlen: {len(batch['input_ids'][0])}") + batch = move_to_device(batch, model.device) + outputs = model(**batch) + + shift_labels = batch["shift_labels"] + loss = unwrapped_model.loss_function( + logits=outputs.logits, + labels=None, + shift_labels=shift_labels, + vocab_size=unwrapped_model.config.vocab_size, + ) + + if sp_size > 1: + # differentiable weighted per-shard-loss aggregation across ranks + losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group) + # special dealing with SFT that has prompt tokens that aren't used in loss computation + good_tokens = (shift_labels != -100).view(-1).sum() + good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group) + total_loss = sum( + losses_per_rank[rank] * good_tokens_per_rank[rank] + for rank in range(sp_world_size) + if good_tokens_per_rank[rank] > 0 + ) + total_good_tokens = sum(good_tokens_per_rank) + loss = total_loss / max(total_good_tokens, 1) + + if rank == 0: + accelerator.print(f"{iter}: {loss=}") + accelerator.log(dict(train_loss=loss, step=iter)) + + accelerator.backward(loss) + optimizer.step() + +accelerator.end_training() diff --git a/accelerate/test_utils/scripts/external_deps/test_ds_multiple_model.py b/accelerate/test_utils/scripts/external_deps/test_ds_multiple_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a9ff3f3d2d0214f2b36f2b51e1d029b6ddf7cb7c --- /dev/null +++ b/accelerate/test_utils/scripts/external_deps/test_ds_multiple_model.py @@ -0,0 +1,331 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Test script for verifying multiple models can be utilized with Accelerate + DeepSpeed: + +Scenario 1: One model is training, another model is being used for inference/logits to impact training in some form. +Scenario 2: Two models are training simultaneously, which means two optimizers, etc. +""" + +import argparse +from pathlib import Path + +import evaluate +import torch +from datasets import load_dataset +from torch.optim import AdamW +from torch.utils.data import DataLoader +from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup + +from accelerate import Accelerator, DeepSpeedPlugin, DistributedType +from accelerate.state import AcceleratorState +from accelerate.utils.deepspeed import get_active_deepspeed_plugin + + +EVAL_BATCH_SIZE = 16 + + +class NoiseModel(torch.nn.Module): + def __init__(self, noise_factor=0.1): + super().__init__() + self.noise_factor = torch.nn.Parameter(torch.tensor(noise_factor, dtype=torch.float32)) + + def forward(self, loss): + return loss * self.noise_factor + + +def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, model_name: str = "bert-base-cased"): + """ + Creates a set of `DataLoader`s for the `glue` dataset. + + Args: + accelerator (`Accelerator`): + An `Accelerator` object + batch_size (`int`, *optional*): + The batch size for the train and validation DataLoaders. + model_name (`str`, *optional*): + """ + tokenizer = AutoTokenizer.from_pretrained(model_name) + datasets = load_dataset("glue", "mrpc") + + def tokenize_function(examples): + # max_length=None => use the model max length (it's actually the default) + outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None) + return outputs + + # Apply the method we just defined to all the examples in all the splits of the dataset + tokenized_datasets = datasets.map( + tokenize_function, batched=True, remove_columns=["idx", "sentence1", "sentence2"], load_from_cache_file=False + ) + + # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the + # transformers library + tokenized_datasets = tokenized_datasets.rename_column("label", "labels") + + def collate_fn(examples): + # On TPU it's best to pad everything to the same length or training will be very slow. + if accelerator.distributed_type == DistributedType.XLA: + return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt") + return tokenizer.pad(examples, padding="longest", return_tensors="pt") + + # Instantiate dataloaders. + train_dataloader = DataLoader( + tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size + ) + eval_dataloader = DataLoader( + tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE + ) + + return train_dataloader, eval_dataloader + + +test_file_path = __file__ +path = Path(test_file_path).resolve() +test_file_dir_str = str(path.parent.parent.parent.parent.parent.parent) + +# Create our DS plugins +# We use custom schedulers and optimizers, hence `model_only` +ds_config_file = dict( + zero2=f"{test_file_dir_str}/tests/deepspeed/ds_config_zero2_model_only.json", + zero3=f"{test_file_dir_str}/tests/deepspeed/ds_config_zero3_model_only.json", +) + + +def single_model_training(config, args): + # Training a single model, we have a `noise` model that is untrainable used to inject some noise into the training process + num_epochs = config["num_epochs"] + zero2_plugin = DeepSpeedPlugin(hf_ds_config=ds_config_file["zero2"]) + zero3_plugin = DeepSpeedPlugin(hf_ds_config=ds_config_file["zero3"]) + + deepspeed_plugins = {"training": zero2_plugin, "inference": zero3_plugin} + + # Initialize accelerator + accelerator = Accelerator( + deepspeed_plugins=deepspeed_plugins, + mixed_precision="bf16", + ) + + # Initialize model under zero2 plugin + assert get_active_deepspeed_plugin(accelerator.state) is zero2_plugin + train_model = AutoModelForSequenceClassification.from_pretrained(args.model_name_or_path) + train_dataloader, eval_dataloader = get_dataloaders( + accelerator, batch_size=config["batch_size"], model_name=args.model_name_or_path + ) + max_training_steps = len(train_dataloader) * config["num_epochs"] + optimizer = AdamW(train_model.parameters(), lr=config["lr"]) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer, num_warmup_steps=0, num_training_steps=max_training_steps + ) + + train_dataloader, eval_dataloader, train_model, optimizer, lr_scheduler = accelerator.prepare( + train_dataloader, eval_dataloader, train_model, optimizer, lr_scheduler + ) + + # Now prepare the model under zero3 plugin + accelerator.state.select_deepspeed_plugin("inference") + assert get_active_deepspeed_plugin(accelerator.state) is zero3_plugin + inference_model = NoiseModel() + inference_model = accelerator.prepare(inference_model) + inference_model.eval() + + # Run training loop + accelerator.state.select_deepspeed_plugin("training") + # We also need to keep track of the stating epoch so files are named properly + starting_epoch = 0 + + # Now we train the model + best_performance = 0 + metric = evaluate.load("glue", "mrpc") + performance_metric = {} + for epoch in range(starting_epoch, num_epochs): + train_model.train() + inference_model.train() + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(train_model): + outputs_1 = train_model(**batch) + with torch.no_grad(): + outputs_2 = inference_model(outputs_1.loss) + # Combine the losses + loss = outputs_1.loss + outputs_2 + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + train_model.eval() + for step, batch in enumerate(eval_dataloader): + with torch.no_grad(): + outputs = train_model(**batch) + predictions = outputs.logits.argmax(dim=-1) + # It is slightly faster to call this once, than multiple times + predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"])) + metric.add_batch( + predictions=predictions, + references=references, + ) + + eval_metric = metric.compute() + # Use accelerator.print to print only on the main process. + accelerator.print(f"epoch {epoch}:", eval_metric) + performance_metric[f"epoch-{epoch}"] = eval_metric["accuracy"] + + if best_performance < eval_metric["accuracy"]: + best_performance = eval_metric["accuracy"] + assert best_performance > performance_metric["epoch-0"] + + +def multiple_model_training(config, args): + # This will essentially be like a k-fold model, but one model is Zero-2 and another model is Zero-3 + num_epochs = config["num_epochs"] + zero2_plugin = DeepSpeedPlugin(hf_ds_config=ds_config_file["zero2"]) + zero3_plugin = DeepSpeedPlugin(hf_ds_config=ds_config_file["zero3"]) + + deepspeed_plugins = {"zero2": zero2_plugin, "zero3": zero3_plugin} + + # Initialize accelerator + zero2_accelerator = Accelerator( + deepspeed_plugins=deepspeed_plugins, + mixed_precision="bf16", + ) + + # Since an `AcceleratorState` has already been made, we can just reuse it here + zero3_accelerator = Accelerator() + + # Initialize model under zero2 plugin + assert get_active_deepspeed_plugin(zero2_accelerator.state) is zero2_plugin + zero2_model = AutoModelForSequenceClassification.from_pretrained(args.model_name_or_path) + train_dataloader, eval_dataloader = get_dataloaders( + zero2_accelerator, batch_size=config["batch_size"], model_name=args.model_name_or_path + ) + max_training_steps = len(train_dataloader) * config["num_epochs"] + zero2_optimizer = AdamW(zero2_model.parameters(), lr=config["lr"]) + zero2_lr_scheduler = get_linear_schedule_with_warmup( + zero2_optimizer, num_warmup_steps=0, num_training_steps=max_training_steps + ) + + train_dataloader, eval_dataloader, zero2_model, zero2_optimizer, zero2_lr_scheduler = zero2_accelerator.prepare( + train_dataloader, eval_dataloader, zero2_model, zero2_optimizer, zero2_lr_scheduler + ) + assert zero2_accelerator.deepspeed_engine_wrapped.engine is zero2_model + + # now do Zero3 + zero3_accelerator.state.select_deepspeed_plugin("zero3") + zero3_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = zero2_plugin.deepspeed_config[ + "train_micro_batch_size_per_gpu" + ] + assert get_active_deepspeed_plugin(zero3_accelerator.state) is zero3_plugin + zero3_model = AutoModelForSequenceClassification.from_pretrained(args.model_name_or_path) + zero3_optimizer = AdamW(zero3_model.parameters(), lr=config["lr"]) + zero3_lr_scheduler = get_linear_schedule_with_warmup( + zero3_optimizer, num_warmup_steps=0, num_training_steps=max_training_steps + ) + zero3_model, zero3_optimizer, zero3_lr_scheduler = zero3_accelerator.prepare( + zero3_model, zero3_optimizer, zero3_lr_scheduler + ) + assert zero3_accelerator.deepspeed_engine_wrapped.engine is zero3_model + + # Run training loop + starting_epoch = 0 + + # Now we train the model + best_performance_a = 0 + best_performance_b = 0 + metric_a = evaluate.load("glue", "mrpc") + metric_b = evaluate.load("glue", "mrpc") + performance_metric_a = {} + performance_metric_b = {} + for epoch in range(starting_epoch, num_epochs): + zero2_model.train() + zero3_model.train() + for step, batch in enumerate(train_dataloader): + with zero2_accelerator.accumulate(zero2_model, zero3_model): + outputs_1 = zero2_model(**batch) + zero2_accelerator.backward(outputs_1.loss) + zero2_optimizer.step() + zero2_lr_scheduler.step() + zero2_optimizer.zero_grad() + outputs_2 = zero3_model(**batch) + zero3_accelerator.backward(outputs_2.loss) + zero3_optimizer.step() + zero3_lr_scheduler.step() + zero3_optimizer.zero_grad() + + zero2_model.eval() + zero3_model.eval() + for step, batch in enumerate(eval_dataloader): + with torch.no_grad(): + logits_a = zero2_model(**batch).logits + logits_b = zero3_model(**batch).logits + # Combine the logits from both models + predictions_a = logits_a.argmax(dim=-1) + predictions_b = logits_b.argmax(dim=-1) + # It is slightly faster to call this once, than multiple times + predictions_a, predictions_b, references = zero2_accelerator.gather_for_metrics( + (predictions_a, predictions_b, batch["labels"]) + ) + metric_a.add_batch( + predictions=predictions_a, + references=references, + ) + metric_b.add_batch( + predictions=predictions_b, + references=references, + ) + + eval_metric_a = metric_a.compute() + eval_metric_b = metric_b.compute() + # Use accelerator.print to print only on the main process. + zero2_accelerator.print(f"epoch {epoch}:", eval_metric_a, eval_metric_b) + performance_metric_a[f"epoch-{epoch}"] = eval_metric_a["accuracy"] + performance_metric_b[f"epoch-{epoch}"] = eval_metric_b["accuracy"] + + if best_performance_a < eval_metric_a["accuracy"]: + best_performance_a = eval_metric_a["accuracy"] + if best_performance_b < eval_metric_b["accuracy"]: + best_performance_b = eval_metric_b["accuracy"] + assert best_performance_a > performance_metric_a["epoch-0"] + assert best_performance_b > performance_metric_b["epoch-0"] + + +def main(): + parser = argparse.ArgumentParser(description="Simple example of training script tracking peak GPU memory usage.") + parser.add_argument( + "--model_name_or_path", + type=str, + default="bert-base-cased", + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=False, + ) + parser.add_argument( + "--performance_lower_bound", + type=float, + default=None, + help="Optional lower bound for the performance metric. If set, the training will throw error when the performance metric drops below this value.", + ) + parser.add_argument( + "--num_epochs", + type=int, + default=3, + help="Number of train epochs.", + ) + args = parser.parse_args() + config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 8} + single_model_training(config, args) + AcceleratorState._reset_state(True) + multiple_model_training(config, args) + + +if __name__ == "__main__": + main() diff --git a/accelerate/test_utils/scripts/external_deps/test_metrics.py b/accelerate/test_utils/scripts/external_deps/test_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..d1bfe351509148ebc48067584e9d61b93e7210a6 --- /dev/null +++ b/accelerate/test_utils/scripts/external_deps/test_metrics.py @@ -0,0 +1,307 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import math +import os +from copy import deepcopy + +import datasets +import evaluate +import torch +import transformers +from datasets import load_dataset +from torch.utils.data import DataLoader, IterableDataset +from transformers import AutoModelForSequenceClassification, AutoTokenizer + +from accelerate import Accelerator, DataLoaderConfiguration, DistributedType +from accelerate.data_loader import DataLoaderDispatcher +from accelerate.test_utils import RegressionDataset, RegressionModel, torch_device +from accelerate.utils import is_torch_xla_available, set_seed + + +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" + + +class ListHandler(logging.Handler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.logs = [] + + def emit(self, record): + self.logs.append(record) + + +def get_basic_setup(accelerator, num_samples=82, batch_size=16): + "Returns everything needed to perform basic training" + set_seed(42) + model = RegressionModel() + ddp_model = deepcopy(model) + dset = RegressionDataset(length=num_samples) + dataloader = DataLoader(dset, batch_size=batch_size) + model.to(accelerator.device) + ddp_model, dataloader = accelerator.prepare(ddp_model, dataloader) + return model, ddp_model, dataloader + + +def get_dataloader(accelerator: Accelerator, use_longest=False): + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/mrpc-bert-base-cased") + dataset = load_dataset("glue", "mrpc", split="validation") + + def tokenize_function(examples): + outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None) + return outputs + + with accelerator.main_process_first(): + tokenized_datasets = dataset.map( + tokenize_function, + batched=True, + remove_columns=["idx", "sentence1", "sentence2"], + ) + + tokenized_datasets = tokenized_datasets.rename_column("label", "labels") + + def collate_fn(examples): + if use_longest: + return tokenizer.pad(examples, padding="longest", return_tensors="pt") + return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt") + + return DataLoader(tokenized_datasets, shuffle=False, collate_fn=collate_fn, batch_size=16) + + +def get_mrpc_setup(dispatch_batches, split_batches): + dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, split_batches=split_batches) + accelerator = Accelerator(dataloader_config=dataloader_config) + dataloader = get_dataloader(accelerator, not dispatch_batches) + model = AutoModelForSequenceClassification.from_pretrained( + "hf-internal-testing/mrpc-bert-base-cased", return_dict=True + ) + ddp_model, ddp_dataloader = accelerator.prepare(model, dataloader) + return { + "ddp": [ddp_model, ddp_dataloader, torch_device], + "no": [model, dataloader, accelerator.device], + }, accelerator + + +def generate_predictions(model, dataloader, accelerator): + logits_and_targets = [] + for batch in dataloader: + input, target = batch.values() + with torch.no_grad(): + logit = model(input) + logit, target = accelerator.gather_for_metrics((logit, target)) + logits_and_targets.append((logit, target)) + logits, targs = [], [] + for logit, targ in logits_and_targets: + logits.append(logit) + targs.append(targ) + logits, targs = torch.cat(logits), torch.cat(targs) + return logits, targs + + +def test_torch_metrics( + accelerator: Accelerator, num_samples=82, dispatch_batches=False, split_batches=False, batch_size=16 +): + _, ddp_model, dataloader = get_basic_setup(accelerator, num_samples, batch_size) + logits, _ = generate_predictions(ddp_model, dataloader, accelerator) + assert len(logits) == num_samples, ( + f"Unexpected number of inputs:\n Expected: {num_samples}\n Actual: {len(logits)}" + ) + + +def test_mrpc(dispatch_batches: bool = False, split_batches: bool = False): + metric = evaluate.load("glue", "mrpc") + setup, accelerator = get_mrpc_setup(dispatch_batches, split_batches) + # First do baseline + model, dataloader, device = setup["no"] + model.to(device) + model.eval() + for batch in dataloader: + batch.to(device) + with torch.inference_mode(): + outputs = model(**batch) + preds = outputs.logits.argmax(dim=-1) + metric.add_batch(predictions=preds, references=batch["labels"]) + baseline = metric.compute() + + # Then do distributed + model, dataloader, device = setup["ddp"] + model.eval() + for batch in dataloader: + with torch.inference_mode(): + outputs = model(**batch) + preds = outputs.logits.argmax(dim=-1) + references = batch["labels"] + preds, references = accelerator.gather_for_metrics((preds, references)) + metric.add_batch(predictions=preds, references=references) + distributed = metric.compute() + + for key in "accuracy f1".split(): + assert math.isclose(baseline[key], distributed[key]), ( + f"Baseline and Distributed are not the same for key {key}:\n\tBaseline: {baseline[key]}\n\tDistributed: {distributed[key]}\n" + ) + + +def test_gather_for_metrics_with_non_tensor_objects_iterable_dataset(): + class DummyIterableDataset(IterableDataset): + def __init__(self, data): + self.data = data + + def __len__(self): + return len(self.data) + + def __iter__(self): + yield from self.data + + iterable_dataset = DummyIterableDataset([n for n in range(30)]) + dataloader = DataLoader(iterable_dataset, batch_size=4) + accelerator = Accelerator() + prepared_dataloader = accelerator.prepare(dataloader) + + if accelerator.is_main_process: + logger = logging.root.manager.loggerDict["accelerate.accelerator"] + list_handler = ListHandler() + logger.addHandler(list_handler) + + batches_for_metrics = [] + for batch in prepared_dataloader: + batches_for_metrics.append(accelerator.gather_for_metrics(batch)) + + assert torch.cat(batches_for_metrics).size(0) == 30 + + if accelerator.is_main_process: + assert len(list_handler.logs) == 0 + logger.removeHandler(list_handler) + + +def test_gather_for_metrics_with_iterable_dataset(): + class DummyIterableDataset(IterableDataset): + def __init__(self, data): + self.data = data + + def __len__(self): + return len(self.data) + + def __iter__(self): + yield from self.data + + iterable_dataset = DummyIterableDataset(torch.as_tensor(range(30))) + dataloader = DataLoader(iterable_dataset, batch_size=4) + + accelerator = Accelerator() + prepared_dataloader = accelerator.prepare(dataloader) + + assert isinstance(prepared_dataloader, DataLoaderDispatcher) + + if accelerator.is_main_process: + logger = logging.root.manager.loggerDict["accelerate.accelerator"] + list_handler = ListHandler() + logger.addHandler(list_handler) + + batches_for_metrics = [] + for batch in prepared_dataloader: + batches_for_metrics.append(accelerator.gather_for_metrics(batch)) + + assert torch.cat(batches_for_metrics).size(0) == 30 + + if accelerator.is_main_process: + assert len(list_handler.logs) == 0 + + logger.removeHandler(list_handler) + + +def test_gather_for_metrics_drop_last(): + accelerator = Accelerator() + per_device_batch_size = 5 + num_items = (10 * accelerator.num_processes) + 1 + dataloader = DataLoader(range(num_items), batch_size=per_device_batch_size, drop_last=True) + dataloader = accelerator.prepare(dataloader) + + iterator = iter(dataloader) + next(iterator) # Skip first batch tensor([0, 1, 2, 3, 4], device='cuda:0') + batch = next(iterator) + gathered_items = accelerator.gather_for_metrics(batch) + + # Should return a full set of complete batches from each GPU + num_expected_items = per_device_batch_size * accelerator.num_processes + assert gathered_items.size(0) == (num_expected_items), ( + f"Expected number of items: {num_expected_items}, Actual: {gathered_items.size(0)}" + ) + + +def main(): + dataloader_config = DataLoaderConfiguration(split_batches=False, dispatch_batches=False) + accelerator = Accelerator(dataloader_config=dataloader_config) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + # TorchXLA does not support batch dispatching. 'put_on_device' is always False for + # TorchXLA, which can cause a value error in 'prepare_data_loader' function. + dispatch_batches_options = [False] if accelerator.state.distributed_type == DistributedType.XLA else [True, False] + + # Temporarily close this test for TorchXLA due to the 'Cannot set version_counter for + # inference tensor' error in inference mode. Reopen it after TorchXLA fixes this bug. + # These are a bit slower so they should only be ran on the GPU or TPU + if accelerator.device.type != "cpu" and not is_torch_xla_available(): + if accelerator.is_local_main_process: + print("**Testing gather_for_metrics**") + for split_batches in [True, False]: + for dispatch_batches in dispatch_batches_options: + if accelerator.is_local_main_process: + print(f"With: `split_batches={split_batches}`, `dispatch_batches={dispatch_batches}`") + test_mrpc(dispatch_batches, split_batches) + accelerator.state._reset_state() + print("test_gather_for_metrics_with_iterable_dataset") + test_gather_for_metrics_with_iterable_dataset() + print("test gather_for_metrics_with_non_tensor_objects_iterable_dataset") + test_gather_for_metrics_with_non_tensor_objects_iterable_dataset() + + # MpDeviceLoader in TorchXLA is an asynchronous loader that preloads several batches into cache. + # This can cause the 'end_of_dataloader' of DataLoaderStateMixin to be set earlier than intended. + # Skip this test when TorchXLA is enabled. + if accelerator.state.distributed_type != DistributedType.XLA: + if accelerator.is_local_main_process: + print("**Test torch metrics**") + for split_batches in [True, False]: + for dispatch_batches in dispatch_batches_options: + dataloader_config = DataLoaderConfiguration( + split_batches=split_batches, dispatch_batches=dispatch_batches + ) + accelerator = Accelerator(dataloader_config=dataloader_config) + if accelerator.is_local_main_process: + print(f"With: `split_batches={split_batches}`, `dispatch_batches={dispatch_batches}`, length=99") + test_torch_metrics(accelerator, 99) + accelerator.state._reset_state() + if accelerator.is_local_main_process: + print("**Test last batch is not dropped when perfectly divisible**") + accelerator = Accelerator() + test_torch_metrics(accelerator, 512) + accelerator.state._reset_state() + if accelerator.is_local_main_process: + print("**Test that `drop_last` is taken into account**") + test_gather_for_metrics_drop_last() + accelerator.end_training() + accelerator.state._reset_state() + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/accelerate/test_utils/scripts/external_deps/test_peak_memory_usage.py b/accelerate/test_utils/scripts/external_deps/test_peak_memory_usage.py new file mode 100644 index 0000000000000000000000000000000000000000..5d75259e9959c1b9f75850d81546a5e98dd206e8 --- /dev/null +++ b/accelerate/test_utils/scripts/external_deps/test_peak_memory_usage.py @@ -0,0 +1,323 @@ +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import gc +import json +import os + +import torch +from datasets import load_dataset +from torch.optim import AdamW +from torch.utils.data import DataLoader +from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed + +from accelerate import Accelerator, DistributedType +from accelerate.utils import ( + is_hpu_available, + is_mlu_available, + is_musa_available, + is_neuron_available, + is_npu_available, + is_sdaa_available, + is_xpu_available, +) +from accelerate.utils.deepspeed import DummyOptim, DummyScheduler + + +MAX_GPU_BATCH_SIZE = 16 +EVAL_BATCH_SIZE = 32 + + +# Converting Bytes to Megabytes +def b2mb(x): + return int(x / 2**20) + + +# This context manager is used to track the peak memory usage of the process +class TorchTracemalloc: + def __enter__(self): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero + self.begin = torch.cuda.memory_allocated() + elif is_mlu_available(): + torch.mlu.empty_cache() + torch.mlu.reset_max_memory_allocated() # reset the peak gauge to zero + self.begin = torch.mlu.memory_allocated() + elif is_sdaa_available(): + torch.sdaa.empty_cache() + torch.sdaa.reset_max_memory_allocated() # reset the peak gauge to zero + self.begin = torch.sdaa.memory_allocated() + elif is_musa_available(): + torch.musa.empty_cache() + torch.musa.reset_max_memory_allocated() # reset the peak gauge to zero + self.begin = torch.musa.memory_allocated() + elif is_npu_available(): + torch.npu.empty_cache() + torch.npu.reset_max_memory_allocated() # reset the peak gauge to zero + self.begin = torch.npu.memory_allocated() + elif is_xpu_available(): + torch.xpu.empty_cache() + torch.xpu.reset_peak_memory_stats() # reset the peak gauge to zero + self.begin = torch.xpu.memory_allocated() + elif is_hpu_available(): + # torch.hpu.empty_cache() # not available on hpu as it reserves all device memory for the current process + torch.hpu.reset_peak_memory_stats() # reset the peak gauge to zero + self.begin = torch.hpu.memory_allocated() + elif is_neuron_available(): + torch.neuron.empty_cache() + torch.neuron.reset_peak_memory_stats() # reset the peak gauge to zero + self.begin = torch.neuron.memory_allocated() + return self + + def __exit__(self, *exc): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + self.end = torch.cuda.memory_allocated() + self.peak = torch.cuda.max_memory_allocated() + elif is_mlu_available(): + torch.mlu.empty_cache() + self.end = torch.mlu.memory_allocated() + self.begin = torch.mlu.max_memory_allocated() + elif is_sdaa_available(): + torch.sdaa.empty_cache() + self.end = torch.sdaa.memory_allocated() + self.begin = torch.sdaa.max_memory_allocated() + elif is_musa_available(): + torch.musa.empty_cache() + self.end = torch.musa.memory_allocated() + self.begin = torch.musa.max_memory_allocated() + elif is_npu_available(): + torch.npu.empty_cache() + self.end = torch.npu.memory_allocated() + self.peak = torch.npu.max_memory_allocated() + elif is_xpu_available(): + torch.xpu.empty_cache() + self.end = torch.xpu.memory_allocated() + self.peak = torch.xpu.max_memory_allocated() + elif is_hpu_available(): + # torch.hpu.empty_cache() # not available on hpu as it reserves all device memory for the current process + self.end = torch.hpu.memory_allocated() + self.peak = torch.hpu.max_memory_allocated() + elif is_neuron_available(): + torch.neuron.empty_cache() + self.end = torch.neuron.memory_allocated() + self.peak = torch.neuron.max_memory_allocated() + self.used = b2mb(self.end - self.begin) + self.peaked = b2mb(self.peak - self.begin) + # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}") + + +def get_dataloaders( + accelerator: Accelerator, + batch_size: int = 16, + model_name: str = "bert-base-cased", + n_train: int = 320, + n_val: int = 160, +): + """ + Creates a set of `DataLoader`s for the `glue` dataset. + + Args: + accelerator (`Accelerator`): + An `Accelerator` object + batch_size (`int`, *optional*): + The batch size for the train and validation DataLoaders. + model_name (`str`, *optional*): + The name of the model to use. + n_train (`int`, *optional*): + The number of training examples to use. + n_val (`int`, *optional*): + The number of validation examples to use. + """ + tokenizer = AutoTokenizer.from_pretrained(model_name) + datasets = load_dataset( + "glue", "mrpc", split={"train": f"train[:{n_train}]", "validation": f"validation[:{n_val}]"} + ) + + def tokenize_function(examples): + # max_length=None => use the model max length (it's actually the default) + outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None) + return outputs + + # Apply the method we just defined to all the examples in all the splits of the dataset + tokenized_datasets = datasets.map( + tokenize_function, batched=True, remove_columns=["idx", "sentence1", "sentence2"], load_from_cache_file=False + ) + + # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the + # transformers library + tokenized_datasets = tokenized_datasets.rename_column("label", "labels") + + def collate_fn(examples): + # On TPU it's best to pad everything to the same length or training will be very slow. + if accelerator.distributed_type == DistributedType.XLA: + return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt") + return tokenizer.pad(examples, padding="longest", return_tensors="pt") + + # Instantiate dataloaders. + train_dataloader = DataLoader( + tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size + ) + eval_dataloader = DataLoader( + tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE + ) + + return train_dataloader, eval_dataloader + + +def training_function(config, args): + # Initialize accelerator + accelerator = Accelerator() + + # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs + lr = config["lr"] + num_epochs = int(config["num_epochs"]) + seed = int(config["seed"]) + batch_size = int(config["batch_size"]) + model_name = args.model_name_or_path + + set_seed(seed) + train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size, model_name, args.n_train, args.n_val) + + # Instantiate the model (we build the model here so that the seed also control new weights initialization) + model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True) + + # Instantiate optimizer + optimizer_cls = ( + AdamW + if accelerator.state.deepspeed_plugin is None + or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config + else DummyOptim + ) + optimizer = optimizer_cls(params=model.parameters(), lr=lr) + + if accelerator.state.deepspeed_plugin is not None: + gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[ + "gradient_accumulation_steps" + ] + else: + gradient_accumulation_steps = 1 + max_training_steps = (len(train_dataloader) * num_epochs) // gradient_accumulation_steps + + # Instantiate scheduler + if ( + accelerator.state.deepspeed_plugin is None + or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config + ): + lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=0, + num_training_steps=max_training_steps, + ) + else: + lr_scheduler = DummyScheduler(optimizer, total_num_steps=max_training_steps, warmup_num_steps=0) + + # Prepare everything + # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the + # prepare method. + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + ) + + # We need to keep track of how many total steps we have iterated over + overall_step = 0 + # We also need to keep track of the stating epoch so files are named properly + starting_epoch = 0 + + # Now we train the model + train_total_peak_memory = {} + for epoch in range(starting_epoch, num_epochs): + with TorchTracemalloc() as tracemalloc: + model.train() + for step, batch in enumerate(train_dataloader): + outputs = model(**batch) + loss = outputs.loss + loss = loss / gradient_accumulation_steps + accelerator.backward(loss) + if step % gradient_accumulation_steps == 0: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + overall_step += 1 + + # Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage + accelerator.print(f"Memory before entering the train : {b2mb(tracemalloc.begin)}") + accelerator.print(f"Memory consumed at the end of the train (end-begin): {tracemalloc.used}") + accelerator.print(f"Peak Memory consumed during the train (max-begin): {tracemalloc.peaked}") + accelerator.print( + f"Total Peak Memory consumed during the train (max): {tracemalloc.peaked + b2mb(tracemalloc.begin)}" + ) + train_total_peak_memory[f"epoch-{epoch}"] = tracemalloc.peaked + b2mb(tracemalloc.begin) + if args.peak_memory_upper_bound is not None: + assert train_total_peak_memory[f"epoch-{epoch}"] <= args.peak_memory_upper_bound, ( + "Peak memory usage exceeded the upper bound" + ) + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + with open(os.path.join(args.output_dir, "peak_memory_utilization.json"), "w") as f: + json.dump(train_total_peak_memory, f) + accelerator.end_training() + + +def main(): + parser = argparse.ArgumentParser(description="Simple example of training script tracking peak GPU memory usage.") + parser.add_argument( + "--model_name_or_path", + type=str, + default="bert-base-cased", + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=False, + ) + parser.add_argument( + "--output_dir", + type=str, + default=".", + help="Optional save directory where all checkpoint folders will be stored. Default is the current working directory.", + ) + parser.add_argument( + "--peak_memory_upper_bound", + type=float, + default=None, + help="The upper bound of peak memory usage in MB. If set, the training will throw an error if the peak memory usage exceeds this value.", + ) + parser.add_argument( + "--n_train", + type=int, + default=320, + help="Number of training examples to use.", + ) + parser.add_argument( + "--n_val", + type=int, + default=160, + help="Number of validation examples to use.", + ) + parser.add_argument( + "--num_epochs", + type=int, + default=1, + help="Number of train epochs.", + ) + args = parser.parse_args() + config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 16} + training_function(config, args) + + +if __name__ == "__main__": + main() diff --git a/accelerate/test_utils/scripts/external_deps/test_performance.py b/accelerate/test_utils/scripts/external_deps/test_performance.py new file mode 100644 index 0000000000000000000000000000000000000000..8e500bdd4c01013904375f010e99d22aea4e4ff9 --- /dev/null +++ b/accelerate/test_utils/scripts/external_deps/test_performance.py @@ -0,0 +1,299 @@ +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json +import os +from contextlib import nullcontext +from pathlib import Path + +import evaluate +import torch +from datasets import load_dataset +from torch.optim import AdamW +from torch.utils.data import DataLoader +from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup + +from accelerate import Accelerator, DistributedType +from accelerate.parallelism_config import ParallelismConfig +from accelerate.utils import SAFE_WEIGHTS_NAME, set_seed +from accelerate.utils.deepspeed import DummyOptim, DummyScheduler + + +MAX_GPU_BATCH_SIZE = 16 +EVAL_BATCH_SIZE = 32 + + +def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, model_name: str = "bert-base-cased"): + """ + Creates a set of `DataLoader`s for the `glue` dataset. + + Args: + accelerator (`Accelerator`): + An `Accelerator` object + batch_size (`int`, *optional*): + The batch size for the train and validation DataLoaders. + model_name (`str`, *optional*): + """ + tokenizer = AutoTokenizer.from_pretrained(model_name) + + datasets = load_dataset("glue", "mrpc") + + def tokenize_function(examples): + # max_length=None => use the model max length (it's actually the default) + outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None) + return outputs + + # Apply the method we just defined to all the examples in all the splits of the dataset + tokenized_datasets = datasets.map( + tokenize_function, batched=True, remove_columns=["idx", "sentence1", "sentence2"], load_from_cache_file=False + ) + + # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the + # transformers library + tokenized_datasets = tokenized_datasets.rename_column("label", "labels") + + def collate_fn(examples): + # On TPU it's best to pad everything to the same length or training will be very slow. + if accelerator.distributed_type == DistributedType.XLA: + return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt") + return tokenizer.pad(examples, padding="longest", return_tensors="pt") + + # Instantiate dataloaders. + train_dataloader = DataLoader( + tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size + ) + eval_dataloader = DataLoader( + tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE + ) + + return train_dataloader, eval_dataloader + + +def training_function(config, args): + accelerator_kwargs = {} + # need this for DeepSpeed tests as `args.tp_size` would be None and `torch.distributed.init_device_mesh` would fail + if args.tp_size is not None: + accelerator_kwargs["parallelism_config"] = ParallelismConfig(tp_size=args.tp_size) + + # Initialize accelerator + accelerator = Accelerator(**accelerator_kwargs) + + # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs + lr = config["lr"] + num_epochs = int(config["num_epochs"]) + seed = int(config["seed"]) + batch_size = int(config["batch_size"]) + model_name = args.model_name_or_path + + set_seed(seed) + train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size, model_name) + + # Add TP related kwargs if provided + model_kwargs = {} + if args.tp_plan is not None: + model_kwargs["tp_plan"] = args.tp_plan + if args.tp_size is not None: + model_kwargs["tp_size"] = args.tp_size + + # Instantiate the model (we build the model here so that the seed also control new weights initialization) + model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True, **model_kwargs) + + if args.add_pad_token: + if model.config.pad_token_id is None: + model.config.pad_token_id = 0 + + # Instantiate optimizer + optimizer_cls = ( + AdamW + if accelerator.state.deepspeed_plugin is None + or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config + else DummyOptim + ) + optimizer = optimizer_cls(params=model.parameters(), lr=lr) + + max_training_steps = len(train_dataloader) * num_epochs + + # Instantiate scheduler + linear_decay_scheduler = False + if ( + accelerator.state.deepspeed_plugin is None + or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config + ): + lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=0, + num_training_steps=max_training_steps, + ) + linear_decay_scheduler = True + else: + lr_scheduler = DummyScheduler(optimizer, total_num_steps=max_training_steps, warmup_num_steps=0) + + # Prepare everything + # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the + # prepare method. + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + ) + + # We also need to keep track of the stating epoch so files are named properly + starting_epoch = 0 + + # Now we train the model + metric = evaluate.load("glue", "mrpc") + best_performance = 0 + performance_metric = {} + expected_lr_after_first_optim_step = lr * ( + 1 - 1 / (max_training_steps / accelerator.num_processes / accelerator.gradient_accumulation_steps) + ) + lr_scheduler_check_completed = False + for epoch in range(starting_epoch, num_epochs): + model.train() + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(model): + outputs = model(**batch) + loss = outputs.loss + accelerator.backward(loss) + context = nullcontext + if args.tp_plan is not None: + from torch.distributed._tensor.experimental import implicit_replication + + context = implicit_replication + with context(): + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # assert the learning rate after first optimizer step + if ( + accelerator.sync_gradients + and not lr_scheduler_check_completed + and linear_decay_scheduler + and accelerator.state.mixed_precision == "no" + ): + assert lr_scheduler.get_last_lr()[0] == expected_lr_after_first_optim_step, ( + f"Wrong lr found at second step, expected {expected_lr_after_first_optim_step}, got {lr_scheduler.get_last_lr()[0]}" + ) + lr_scheduler_check_completed = True + + model.eval() + samples_seen = 0 + for step, batch in enumerate(eval_dataloader): + # We could avoid this line since we set the accelerator with `device_placement=True`. + batch.to(accelerator.device) + with torch.no_grad(): + outputs = model(**batch) + predictions = outputs.logits.argmax(dim=-1) + # It is slightly faster to call this once, than multiple times + predictions, references = accelerator.gather( + (predictions, batch["labels"]) + ) # If we are in a multiprocess environment, the last batch has duplicates + if accelerator.use_distributed: + if step == len(eval_dataloader) - 1: + predictions = predictions[: len(eval_dataloader.dataset) - samples_seen] + references = references[: len(eval_dataloader.dataset) - samples_seen] + else: + samples_seen += references.shape[0] + metric.add_batch( + predictions=predictions, + references=references, + ) + + eval_metric = metric.compute() + # Use accelerator.print to print only on the main process. + accelerator.print(f"epoch {epoch}:", eval_metric) + performance_metric[f"epoch-{epoch}"] = eval_metric["accuracy"] + + if best_performance < eval_metric["accuracy"]: + best_performance = eval_metric["accuracy"] + + # check that the LR is 0 + if linear_decay_scheduler and accelerator.state.mixed_precision == "no": + assert lr_scheduler.get_last_lr()[0] == 0, ( + f"Wrong lr found at last step, expected 0, got {lr_scheduler.get_last_lr()[0]}" + ) + + if args.performance_lower_bound is not None: + assert args.performance_lower_bound <= best_performance, ( + f"Best performance metric {best_performance} is lower than the lower bound {args.performance_lower_bound}" + ) + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + with open(os.path.join(args.output_dir, "all_results.json"), "w") as f: + json.dump(performance_metric, f) + + # TODO: skip saving of the model test for TP until the feature lands + if args.tp_plan is None: + # Finally try saving the model + accelerator.save_model(model, args.output_dir) + accelerator.wait_for_everyone() + if args.tp_plan is None: + assert Path(args.output_dir, SAFE_WEIGHTS_NAME).exists(), ( + "Model was not saved when calling `Accelerator.save_model`" + ) + accelerator.end_training() + + +def main(): + parser = argparse.ArgumentParser(description="Simple example of training script tracking peak GPU memory usage.") + parser.add_argument( + "--model_name_or_path", + type=str, + default="bert-base-cased", + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=False, + ) + parser.add_argument( + "--output_dir", + type=str, + default=".", + help="Optional save directory where all checkpoint folders will be stored. Default is the current working directory.", + ) + parser.add_argument( + "--performance_lower_bound", + type=float, + default=None, + help="Optional lower bound for the performance metric. If set, the training will throw error when the performance metric drops below this value.", + ) + parser.add_argument( + "--num_epochs", + type=int, + default=3, + help="Number of train epochs.", + ) + parser.add_argument( + "--add_pad_token", + type=bool, + default=False, + help="To add pad token if not exists.", + ) + parser.add_argument( + "--tp_plan", + type=str, + default=None, + help="pass 'auto' to use TP", + ) + parser.add_argument( + "--tp_size", + type=int, + default=None, + help="TP size to be used to shard the model", + ) + args = parser.parse_args() + config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 16} + training_function(config, args) + + +if __name__ == "__main__": + main() diff --git a/accelerate/test_utils/scripts/external_deps/test_pippy.py b/accelerate/test_utils/scripts/external_deps/test_pippy.py new file mode 100644 index 0000000000000000000000000000000000000000..1dbd86c46b4a0c12df8ea4d736c7cd1e03f81813 --- /dev/null +++ b/accelerate/test_utils/scripts/external_deps/test_pippy.py @@ -0,0 +1,117 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from transformers import ( + BertConfig, + BertForMaskedLM, + GPT2Config, + GPT2ForSequenceClassification, +) + +from accelerate import PartialState +from accelerate.inference import prepare_pippy +from accelerate.test_utils import torch_device +from accelerate.utils import DistributedType, set_seed + + +model_to_config = { + "bert": (BertForMaskedLM, BertConfig, 512), + "gpt2": (GPT2ForSequenceClassification, GPT2Config, 1024), +} + + +def get_model_and_data_for_text(model_name, device, num_processes: int = 2): + initializer, config, seq_len = model_to_config[model_name] + config_args = {} + # Eventually needed for batch inference tests on gpt-2 when bs != 1 + # if model_name == "gpt2": + # config_args["pad_token_id"] = 0 + model_config = config(**config_args) + model = initializer(model_config) + kwargs = dict(low=0, high=model_config.vocab_size, device=device, dtype=torch.int64, requires_grad=False) + trace_input = torch.randint(size=(1, seq_len), **kwargs) + inference_inputs = torch.randint(size=(num_processes, seq_len), **kwargs) + return model, trace_input, inference_inputs + + +def test_bert(batch_size: int = 2): + set_seed(42) + state = PartialState() + model, trace_input, inference_inputs = get_model_and_data_for_text("bert", "cpu", batch_size) + model = prepare_pippy(model, example_args=(trace_input,), no_split_module_classes=model._no_split_modules) + # For inference args need to be a tuple + inputs = inference_inputs.to(torch_device) + with torch.no_grad(): + output = model(inputs) + # Zach: Check that we just grab the real outputs we need at the end + if not state.is_last_process: + assert output is None, "Output was not generated on just the last process!" + else: + assert output is not None, "Output was not generated in the last process!" + + +def test_gpt2(batch_size: int = 2): + set_seed(42) + state = PartialState() + model, trace_input, inference_inputs = get_model_and_data_for_text("gpt2", "cpu", batch_size) + model = prepare_pippy(model, example_args=(trace_input,), no_split_module_classes=model._no_split_modules) + # For inference args need to be a tuple + inputs = inference_inputs.to(torch_device) + with torch.no_grad(): + output = model(inputs) + # Zach: Check that we just grab the real outputs we need at the end + if not state.is_last_process: + assert output is None, "Output was not generated on just the last process!" + else: + assert output is not None, "Output was not generated in the last process!" + + +# Currently disabled, enable again once PyTorch pippy interface can trace a resnet34 +# def test_resnet(batch_size: int = 2): +# set_seed(42) +# state = PartialState() +# model = resnet34() +# input_tensor = torch.rand(1, 3, 224, 224) +# model = prepare_pippy( +# model, +# example_args=(input_tensor,), +# ) +# inference_inputs = torch.rand(batch_size, 3, 224, 224) +# inputs = send_to_device(inference_inputs, torch_device) +# with torch.no_grad(): +# output = model(inputs) +# # Zach: Check that we just grab the real outputs we need at the end +# if not state.is_last_process: +# assert output is None, "Output was not generated on just the last process!" +# else: +# assert output is not None, "Output was not generated in the last process!" + + +if __name__ == "__main__": + state = PartialState() + state.print("Testing pippy integration...") + try: + if state.distributed_type in [DistributedType.MULTI_GPU, DistributedType.MULTI_XPU, DistributedType.MULTI_HPU]: + state.print("Testing GPT2...") + test_gpt2() + # Issue: When modifying the tokenizer for batch GPT2 inference, there's an issue + # due to references + # NameError: cannot access free variable 'chunk_args_list' where it is not associated with a value in enclosing scope + # test_gpt2(3) + state.print("Testing BERT...") + test_bert() + else: + print("Less than two GPUs found, not running tests!") + finally: + state.destroy_process_group() diff --git a/accelerate/test_utils/scripts/external_deps/test_zero3_integration.py b/accelerate/test_utils/scripts/external_deps/test_zero3_integration.py new file mode 100644 index 0000000000000000000000000000000000000000..f6e46d342a0c754cf1ed5b2629290d62e7c943aa --- /dev/null +++ b/accelerate/test_utils/scripts/external_deps/test_zero3_integration.py @@ -0,0 +1,59 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.distributed + +from accelerate.test_utils import require_huggingface_suite, torch_device +from accelerate.utils import is_transformers_available + + +if is_transformers_available(): + from transformers import AutoModel, TrainingArguments + + +GPT2_TINY = "sshleifer/tiny-gpt2" + + +@require_huggingface_suite +def init_torch_dist_then_launch_deepspeed(): + if torch_device == "xpu": + backend = "xccl" + elif torch_device == "hpu": + backend = "hccl" + else: + backend = "nccl" + + torch.distributed.init_process_group(backend=backend) + deepspeed_config = { + "zero_optimization": { + "stage": 3, + }, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + } + train_args = TrainingArguments( + output_dir="./", + deepspeed=deepspeed_config, + ) + model = AutoModel.from_pretrained(GPT2_TINY) + assert train_args is not None + assert model is not None + + +def main(): + init_torch_dist_then_launch_deepspeed() + + +if __name__ == "__main__": + main() diff --git a/accelerate/test_utils/scripts/test_cli.py b/accelerate/test_utils/scripts/test_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..fc9dd1d36e8f2949d6e3cebb8ce65efbf0b9e5e4 --- /dev/null +++ b/accelerate/test_utils/scripts/test_cli.py @@ -0,0 +1,32 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from accelerate.utils import is_xpu_available + + +def main(): + accelerator_type = "GPU" + num_accelerators = 0 + if torch.cuda.is_available(): + num_accelerators = torch.cuda.device_count() + accelerator_type = "GPU" + elif is_xpu_available(): + num_accelerators = torch.xpu.device_count() + accelerator_type = "XPU" + print(f"Successfully ran on {num_accelerators} {accelerator_type}s") + + +if __name__ == "__main__": + main() diff --git a/accelerate/test_utils/scripts/test_ddp_comm_hook.py b/accelerate/test_utils/scripts/test_ddp_comm_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..0db5844e026d1c035670e518a8f81d33136ea665 --- /dev/null +++ b/accelerate/test_utils/scripts/test_ddp_comm_hook.py @@ -0,0 +1,85 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs, PartialState +from accelerate.utils import is_hpu_available + + +class MockModel(torch.nn.Module): + def __init__(self): + super().__init__() + torch.manual_seed(0) + self.p = torch.nn.Parameter(torch.randn(40, 20)) + + def forward(self, x, rank): + return self.p * (x ** (1 + rank)) + + +def _run_and_get_grads(model, rank): + torch.manual_seed(2024) + input = torch.randn(40, 20) + output = model(input, rank) + output.mean().backward() + param = next(model.parameters()) + return param.grad + + +def test_ddp_comm_hook(comm_hook, comm_wrapper, comm_state_option): + ddp_kwargs = DistributedDataParallelKwargs( + comm_hook=comm_hook, + comm_wrapper=comm_wrapper, + comm_state_option=comm_state_option, + ) + accelerator = Accelerator(kwargs_handlers=[ddp_kwargs]) + + model = accelerator.prepare(MockModel()) + hook_grads = _run_and_get_grads(model, accelerator.local_process_index) + + reference_model = torch.nn.parallel.DistributedDataParallel( + MockModel().to(accelerator.device), + device_ids=[accelerator.local_process_index], + output_device=accelerator.local_process_index, + ) + reference_grads = _run_and_get_grads(reference_model, accelerator.local_process_index) + + torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-2, atol=1e-2) + + +def main(): + for comm_hook, comm_wrapper, comm_state_option in [ + (DDPCommunicationHookType.NO, DDPCommunicationHookType.NO, {}), + (DDPCommunicationHookType.FP16, DDPCommunicationHookType.NO, {}), + (DDPCommunicationHookType.BF16, DDPCommunicationHookType.NO, {}), + (DDPCommunicationHookType.POWER_SGD, DDPCommunicationHookType.NO, {}), + (DDPCommunicationHookType.POWER_SGD, DDPCommunicationHookType.FP16, {}), + (DDPCommunicationHookType.POWER_SGD, DDPCommunicationHookType.BF16, {}), + (DDPCommunicationHookType.POWER_SGD, DDPCommunicationHookType.NO, {"matrix_approximation_rank": 2}), + (DDPCommunicationHookType.BATCHED_POWER_SGD, DDPCommunicationHookType.NO, {}), + (DDPCommunicationHookType.BATCHED_POWER_SGD, DDPCommunicationHookType.FP16, {}), + (DDPCommunicationHookType.BATCHED_POWER_SGD, DDPCommunicationHookType.BF16, {}), + ]: + if is_hpu_available(): + HPU_UNSUPPORTED_COMM_HOOKS = {DDPCommunicationHookType.FP16, DDPCommunicationHookType.BF16} + if comm_hook in HPU_UNSUPPORTED_COMM_HOOKS or comm_wrapper in HPU_UNSUPPORTED_COMM_HOOKS: + print(f"Skipping test DDP comm hook: {comm_hook}, comm wrapper: {comm_wrapper} on HPU") + continue + + print(f"Test DDP comm hook: {comm_hook}, comm wrapper: {comm_wrapper}") + test_ddp_comm_hook(comm_hook, comm_wrapper, comm_state_option) + PartialState().destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/accelerate/test_utils/scripts/test_distributed_data_loop.py b/accelerate/test_utils/scripts/test_distributed_data_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..84a9247a6bce04a55a67ad33e1761c627772aced --- /dev/null +++ b/accelerate/test_utils/scripts/test_distributed_data_loop.py @@ -0,0 +1,429 @@ +#!/usr/bin/env python + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pickle +import tempfile +import warnings +from unittest.mock import Mock + +import torch +from torch.utils.data import ( + BatchSampler, + DataLoader, + Dataset, + IterableDataset, + RandomSampler, + TensorDataset, + default_collate, +) + +from accelerate.accelerator import Accelerator, DataLoaderConfiguration +from accelerate.utils.dataclasses import DistributedType + + +NUM_ELEMENTS = 22 +NUM_WORKERS = 4 +BATCH_SIZE = 4 + + +class DummyDataset(Dataset): + def __len__(self): + return NUM_ELEMENTS + + def __getitem__(self, index): + squeeze = False + + if isinstance(index, int): + index = [index] + squeeze = True + elif isinstance(index, slice): + index = list(range(*index.indices(self.size))) + else: + index = list(index) + + batch = [{"index": i, "label": i % 2, "random_augmentation": torch.rand(1).item()} for i in index] + + if squeeze: + batch = batch[0] + + return batch + + +class DummyIterableDataset(IterableDataset): + def __init__(self, data): + self.data = data + + def __iter__(self): + yield from self.data + + +def create_accelerator(even_batches=True): + dataloader_config = DataLoaderConfiguration(even_batches=even_batches) + accelerator = Accelerator(dataloader_config=dataloader_config) + assert accelerator.num_processes == 2, "this script expects that two GPUs are available" + return accelerator + + +def create_dataloader( + accelerator: Accelerator, dataset_size: int, batch_size: int, iterable: bool = False, shuffle: bool = False +): + """ + Create a simple DataLoader to use during the test cases + """ + values = torch.as_tensor(range(dataset_size)) + if shuffle: + values = values[torch.randperm(values.size(0))] + if iterable: + dataset = DummyIterableDataset(values) + else: + dataset = TensorDataset(torch.as_tensor(range(dataset_size))) + + dl = DataLoader(dataset, batch_size=batch_size) + dl = accelerator.prepare(dl) + + return dl + + +def verify_dataloader_batch_sizes( + accelerator: Accelerator, + dataset_size: int, + batch_size: int, + process_0_expected_batch_sizes: list[int], + process_1_expected_batch_sizes: list[int], +): + """ + A helper function for verifying the batch sizes coming from a prepared dataloader in each process + """ + dl = create_dataloader(accelerator=accelerator, dataset_size=dataset_size, batch_size=batch_size) + + batch_sizes = [len(batch[0]) for batch in dl] + + if accelerator.process_index == 0: + assert batch_sizes == process_0_expected_batch_sizes + elif accelerator.process_index == 1: + assert batch_sizes == process_1_expected_batch_sizes + + +def test_default_ensures_even_batch_sizes(): + accelerator = create_accelerator() + + # without padding, we would expect a different number of batches + verify_dataloader_batch_sizes( + accelerator, + dataset_size=3, + batch_size=1, + process_0_expected_batch_sizes=[1, 1], + process_1_expected_batch_sizes=[1, 1], + ) + + # without padding, we would expect the same number of batches, but different sizes + verify_dataloader_batch_sizes( + accelerator, + dataset_size=7, + batch_size=2, + process_0_expected_batch_sizes=[2, 2], + process_1_expected_batch_sizes=[2, 2], + ) + + +def test_can_disable_even_batches(): + accelerator = create_accelerator(even_batches=False) + + verify_dataloader_batch_sizes( + accelerator, + dataset_size=3, + batch_size=1, + process_0_expected_batch_sizes=[1, 1], + process_1_expected_batch_sizes=[1], + ) + + verify_dataloader_batch_sizes( + accelerator, + dataset_size=7, + batch_size=2, + process_0_expected_batch_sizes=[2, 2], + process_1_expected_batch_sizes=[2, 1], + ) + + +def test_can_join_uneven_inputs(): + accelerator = create_accelerator(even_batches=False) + + model = torch.nn.Linear(1, 1) + ddp_model = accelerator.prepare(model) + + dl = create_dataloader(accelerator, dataset_size=3, batch_size=1) + + batch_idxs = [] + with accelerator.join_uneven_inputs([ddp_model]): + for batch_idx, batch in enumerate(dl): + output = ddp_model(batch[0].float()) + loss = output.sum() + loss.backward() + batch_idxs.append(batch_idx) + + accelerator.wait_for_everyone() + + if accelerator.process_index == 0: + assert batch_idxs == [0, 1] + elif accelerator.process_index == 1: + assert batch_idxs == [0] + + +def test_join_raises_warning_for_non_ddp_distributed(accelerator): + with warnings.catch_warnings(record=True) as w: + with accelerator.join_uneven_inputs([Mock()]): + pass + + assert issubclass(w[-1].category, UserWarning) + assert "only supported for multi-GPU" in str(w[-1].message) + + +def test_join_can_override_even_batches(): + default_even_batches = True + overridden_even_batches = False + accelerator = create_accelerator(even_batches=default_even_batches) + model = torch.nn.Linear(1, 1) + ddp_model = accelerator.prepare(model) + train_dl = create_dataloader(accelerator, dataset_size=3, batch_size=1) + valid_dl = create_dataloader(accelerator, dataset_size=3, batch_size=1) + + with accelerator.join_uneven_inputs([ddp_model], even_batches=overridden_even_batches): + train_dl_overridden_value = train_dl.batch_sampler.even_batches + valid_dl_overridden_value = valid_dl.batch_sampler.even_batches + + assert train_dl_overridden_value == overridden_even_batches + assert valid_dl_overridden_value == overridden_even_batches + assert train_dl.batch_sampler.even_batches == default_even_batches + assert valid_dl.batch_sampler.even_batches == default_even_batches + + +def test_join_can_override_for_mixed_type_dataloaders(): + default_even_batches = True + overridden_even_batches = False + accelerator = create_accelerator(even_batches=default_even_batches) + model = torch.nn.Linear(1, 1) + ddp_model = accelerator.prepare(model) + create_dataloader(accelerator, dataset_size=3, batch_size=1, iterable=True) + batch_dl = create_dataloader(accelerator, dataset_size=3, batch_size=1) + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + try: + with accelerator.join_uneven_inputs([ddp_model], even_batches=overridden_even_batches): + batch_dl_overridden_value = batch_dl.batch_sampler.even_batches + except AttributeError: + # ensure attribute error is not raised when processing iterable dl + raise AssertionError + + assert batch_dl_overridden_value == overridden_even_batches + assert batch_dl.batch_sampler.even_batches == default_even_batches + + +def test_join_raises_warning_for_iterable_when_overriding_even_batches(): + accelerator = create_accelerator() + model = torch.nn.Linear(1, 1) + ddp_model = accelerator.prepare(model) + create_dataloader(accelerator, dataset_size=3, batch_size=1, iterable=True) + + with warnings.catch_warnings(record=True) as w: + with accelerator.join_uneven_inputs([ddp_model], even_batches=False): + pass + + assert issubclass(w[-1].category, UserWarning) + assert "only supported for map-style datasets" in str(w[-1].message) + + +def test_pickle_accelerator(): + accelerator = create_accelerator() + data_loader = create_dataloader(accelerator, dataset_size=32, batch_size=4) + _ = accelerator.prepare(data_loader) + pickled_accelerator = pickle.dumps(accelerator) + unpickled_accelerator = pickle.loads(pickled_accelerator) + # TODO: Maybe this should be implemented as __eq__ for AcceleratorState? + assert accelerator.state.__dict__ == unpickled_accelerator.state.__dict__ + + +def test_data_loader(data_loader, accelerator): + # Prepare the DataLoader + data_loader = accelerator.prepare(data_loader) + + all_examples = [] + for i, batch in enumerate(data_loader): + index, _ = accelerator.gather_for_metrics((batch["index"], batch["label"])) + all_examples.extend(index.detach().cpu().numpy().tolist()) + + # Sort the examples + sorted_all_examples = sorted(all_examples) + + # Check if all elements are present in the sorted list of iterated samples + assert len(set(sorted_all_examples)) == NUM_ELEMENTS, ( + "Not all the dataset elements have been iterated in an epoch due to duplication of samples across processes." + ) + + +def _test_stateful_dataloader_resume(accelerator, iterable): + """ + Helper: iterate a stateful dataloader, save state after a few batches using `load_state_dict`, + resume from the saved state, and verify the resumed batches match what was originally unseen. + + Saves early (after 3 batches) so many batches remain, exposing any off-by-one in state restoration. + Tested with both iterable and map-style datasets to cover different state_dict code paths. + """ + old_dataloader_config = accelerator.dataloader_config + try: + accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True) + prepared_dl = create_dataloader( + accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=iterable, shuffle=True + ) + untrained_batches = [] + save_step = 2 + for step, batch in enumerate(prepared_dl): + if step == save_step: + state_dict = prepared_dl.state_dict() + if step > save_step: + untrained_batches.append(batch) + not_skipped_batches = accelerator.gather(untrained_batches) + prepared_dl.load_state_dict(state_dict) + resumed_batches = [] + for batch in prepared_dl: + resumed_batches.append(batch) + resumed_batches = accelerator.gather(resumed_batches) + assert len(not_skipped_batches) == len(resumed_batches), ( + f"Expected {len(not_skipped_batches)} batches after resume, got {len(resumed_batches)}" + ) + for b1, b2 in zip(not_skipped_batches, resumed_batches): + for v1, v2 in zip(b1, b2): + assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal" + finally: + accelerator.dataloader_config = old_dataloader_config + + +def test_stateful_dataloader(accelerator): + """ + Tests that a stateful dataloader can be iterated over, saved after a few batches using `load_state_dict`, and then + resumed from the saved state. + + The result should be the same as the rest of the data that iterated over after saving. + """ + _test_stateful_dataloader_resume(accelerator, iterable=True) + _test_stateful_dataloader_resume(accelerator, iterable=False) + + +def _test_stateful_dataloader_save_state_resume(accelerator, iterable): + """ + Helper: iterate a stateful dataloader, save state after a few batches using `Accelerator.save_state`, + resume, and verify the resumed batches match what was originally unseen. + """ + old_dataloader_config = accelerator.dataloader_config + try: + with tempfile.TemporaryDirectory() as tmpdir: + accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True) + prepared_dl = create_dataloader( + accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=iterable, shuffle=True + ) + untrained_batches = [] + save_step = 2 + for step, batch in enumerate(prepared_dl): + if step == save_step: + accelerator.save_state(tmpdir) + if step > save_step: + untrained_batches.append(batch) + not_skipped_batches = accelerator.gather(untrained_batches) + accelerator.load_state(tmpdir) + resumed_batches = [] + for batch in prepared_dl: + resumed_batches.append(batch) + resumed_batches = accelerator.gather(resumed_batches) + assert len(not_skipped_batches) == len(resumed_batches), ( + f"Expected {len(not_skipped_batches)} batches after resume, got {len(resumed_batches)}" + ) + for b1, b2 in zip(not_skipped_batches, resumed_batches): + for v1, v2 in zip(b1, b2): + assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal" + finally: + accelerator.dataloader_config = old_dataloader_config + + +def test_stateful_dataloader_save_state(accelerator): + """ + Tests that a stateful dataloader can be iterated over, saved after a few batches using `Accelerator.save_state`, + and then resumed from the saved state. + + The result should be the same as the rest of the data that iterated over after saving. + """ + _test_stateful_dataloader_save_state_resume(accelerator, iterable=True) + _test_stateful_dataloader_save_state_resume(accelerator, iterable=False) + + +def main(): + accelerator = create_accelerator() + torch.manual_seed(accelerator.process_index) + + accelerator.print("Test that even_batches variable ensures uniform batches across processes") + test_default_ensures_even_batch_sizes() + + accelerator.print("Run tests with even_batches disabled") + test_can_disable_even_batches() + + accelerator.print("Test joining uneven inputs") + test_can_join_uneven_inputs() + + accelerator.print("Test overriding even_batches when joining uneven inputs") + test_join_can_override_even_batches() + + accelerator.print("Test overriding even_batches for mixed dataloader types") + test_join_can_override_for_mixed_type_dataloaders() + + accelerator.print("Test overriding even_batches raises a warning for iterable dataloaders") + test_join_raises_warning_for_iterable_when_overriding_even_batches() + + accelerator.print("Test join with non DDP distributed raises warning") + original_state = accelerator.state.distributed_type + accelerator.state.distributed_type = DistributedType.FSDP + test_join_raises_warning_for_non_ddp_distributed(accelerator) + accelerator.state.distributed_type = original_state + + accelerator.print("Test pickling an accelerator") + test_pickle_accelerator() + + dataset = DummyDataset() + + accelerator.print("Test DataLoader with shuffle=False") + loader = DataLoader(dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS) + test_data_loader(loader, accelerator) + + accelerator.print("Test DataLoader with shuffle=True") + loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS) + test_data_loader(loader, accelerator) + + accelerator.print("Test DataLoader with batch_sampler") + sampler = BatchSampler(RandomSampler(dataset), batch_size=BATCH_SIZE, drop_last=False) + loader = DataLoader(dataset, batch_sampler=sampler, num_workers=NUM_WORKERS) + test_data_loader(loader, accelerator) + + accelerator.print("Test DataLoader with sampler as an instance of `BatchSampler`") + sampler = BatchSampler(RandomSampler(dataset), batch_size=BATCH_SIZE, drop_last=False) + loader = DataLoader(dataset, sampler=sampler, batch_size=None, collate_fn=default_collate, num_workers=NUM_WORKERS) + test_data_loader(loader, accelerator) + test_stateful_dataloader(accelerator) + test_stateful_dataloader_save_state(accelerator) + + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/accelerate/test_utils/scripts/test_merge_weights.py b/accelerate/test_utils/scripts/test_merge_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..f280c8fa17919367001389a8efcd78faaa041f0c --- /dev/null +++ b/accelerate/test_utils/scripts/test_merge_weights.py @@ -0,0 +1,158 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import logging +import shutil +from pathlib import Path + +import torch +from safetensors.torch import load_file +from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy, StateDictType +from torch.utils.data import DataLoader + +from accelerate import Accelerator, FullyShardedDataParallelPlugin +from accelerate.commands.merge import merge_command, merge_command_parser +from accelerate.state import AcceleratorState +from accelerate.test_utils import torch_device +from accelerate.test_utils.training import RegressionDataset +from accelerate.utils import merge_fsdp_weights, patch_environment, save_fsdp_model + + +logging.basicConfig(level=logging.INFO) + +parser = merge_command_parser() + + +class TinyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(16, 16) + self.activation = torch.nn.ReLU() + self.linear2 = torch.nn.Linear(16, 16) + self.softmax = torch.nn.Softmax() + + def forward(self, x): + return self.linear2(self.activation(self.linear1(x))) + + +def setup(): + if AcceleratorState._shared_state != {}: + AcceleratorState()._reset_state() + plugin = FullyShardedDataParallelPlugin( + sharding_strategy=ShardingStrategy.FULL_SHARD, state_dict_type=StateDictType.SHARDED_STATE_DICT + ) + model = TinyModel() + with patch_environment(fsdp_auto_wrap_policy="SIZE_BASED_WRAP"): + plugin.set_auto_wrap_policy(model) + accelerator = Accelerator(fsdp_plugin=plugin) + model = accelerator.prepare(model) + return model, plugin, accelerator + + +def mock_training(accelerator, model): + train_set = RegressionDataset(length=128, seed=42) + train_dl = DataLoader(train_set, batch_size=16, shuffle=False) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + + train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer) + for _ in range(3): + for batch in train_dl: + model.zero_grad() + output = model(batch["x"]) + loss = torch.nn.functional.mse_loss(output, batch["y"]) + accelerator.backward(loss) + optimizer.step() + return model + + +def check_weights(operation, state_1, state_2): + for weight_1, weight_2 in zip(state_1.values(), state_2.values()): + if operation == "same": + assert torch.allclose(weight_1, weight_2) + else: + assert not torch.allclose(weight_1, weight_2) + + +def check_safetensors_weights(path, model): + safe_state_dict = load_file(path / "model.safetensors") + safe_loaded_model = TinyModel().to(torch_device) + check_weights("diff", model.state_dict(), safe_loaded_model.state_dict()) + safe_loaded_model.load_state_dict(safe_state_dict) + check_weights("same", model.state_dict(), safe_loaded_model.state_dict()) + + +def check_pytorch_weights(path, model): + nonsafe_state_dict = torch.load(path / "pytorch_model.bin", weights_only=True) + nonsafe_loaded_model = TinyModel().to(torch_device) + check_weights("diff", model.state_dict(), nonsafe_loaded_model.state_dict()) + nonsafe_loaded_model.load_state_dict(nonsafe_state_dict) + check_weights("same", model.state_dict(), nonsafe_loaded_model.state_dict()) + + +def test_merge_weights_safetensors(model, path): + # Should now be saved at `path/merged.safetensors` + merge_fsdp_weights(path / "pytorch_model_fsdp_0", path, safe_serialization=True) + check_safetensors_weights(path, model) + + +def test_merge_weights_command_safetensors(model, path): + args = parser.parse_args([str(path / "pytorch_model_fsdp_0"), str(path)]) + merge_command(args) + check_safetensors_weights(path, model) + + +def test_merge_weights_pytorch(model, path): + # Should now be saved at `path/merged.bin` + merge_fsdp_weights(path / "pytorch_model_fsdp_0", path, safe_serialization=False) + check_pytorch_weights(path, model) + + +def test_merge_weights_command_pytorch(model, path): + args = parser.parse_args([str(path / "pytorch_model_fsdp_0"), str(path), "--unsafe_serialization"]) + merge_command(args) + check_pytorch_weights(path, model) + + +if __name__ == "__main__": + # Note this test requires at least two accelerators! + model, plugin, accelerator = setup() + if accelerator.num_processes > 1: + try: + # Initial setup for things + out_path = Path("test_merge_weights_fsdp_weights") + if not out_path.exists(): + out_path.mkdir(parents=True, exist_ok=True) + + # Train briefly once weights aren't the baseline + model = mock_training(accelerator, model) + accelerator.wait_for_everyone() + + gc.collect() # Needed for some lingering refs after training + save_fsdp_model(plugin, accelerator, model, out_path) + accelerator.wait_for_everyone() + + # Finally we can test + test_merge_weights_safetensors(model, out_path) + test_merge_weights_command_safetensors(model, out_path) + test_merge_weights_pytorch(model, out_path) + test_merge_weights_command_pytorch(model, out_path) + except Exception: + raise + finally: + # Cleanup in case of any failures + if accelerator.is_main_process: + shutil.rmtree(out_path) + accelerator.wait_for_everyone() + accelerator.end_training() diff --git a/accelerate/test_utils/scripts/test_notebook.py b/accelerate/test_utils/scripts/test_notebook.py new file mode 100644 index 0000000000000000000000000000000000000000..bc75a861758e20fcec3196ea75c556679a2632b3 --- /dev/null +++ b/accelerate/test_utils/scripts/test_notebook.py @@ -0,0 +1,125 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test file to ensure that in general certain situational setups for notebooks work. +""" + +import os +import time + +from pytest import mark, raises +from torch.distributed.elastic.multiprocessing.errors import ChildFailedError + +from accelerate import PartialState, notebook_launcher +from accelerate.test_utils import require_bnb +from accelerate.utils import is_bnb_available, is_xpu_available + + +def basic_function(): + # Just prints the PartialState + print(f"PartialState:\n{PartialState()}") + + +def tough_nut_function(queue): + if queue.empty(): + return + trial = queue.get() + if trial > 0: + queue.put(trial - 1) + raise RuntimeError("The nut hasn't cracked yet! Try again.") + + print(f"PartialState:\n{PartialState()}") + + +def bipolar_sleep_function(sleep_sec: int): + state = PartialState() + if state.process_index % 2 == 0: + raise RuntimeError("I'm an even process. I don't like to sleep.") + else: + time.sleep(sleep_sec) + + +NUM_PROCESSES = int(os.environ.get("ACCELERATE_NUM_PROCESSES", 1)) + + +def test_can_initialize(): + notebook_launcher(basic_function, (), num_processes=NUM_PROCESSES) + + +@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test static rendezvous backends") +def test_static_rdzv_backend(): + notebook_launcher(basic_function, (), num_processes=NUM_PROCESSES, rdzv_backend="static") + + +@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test c10d rendezvous backends") +def test_c10d_rdzv_backend(): + notebook_launcher(basic_function, (), num_processes=NUM_PROCESSES, rdzv_backend="c10d") + + +@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test fault tolerance") +def test_fault_tolerant(max_restarts: int = 3): + # Use torch.multiprocessing to get the right context for the current device + import torch.multiprocessing as mp + + # Get appropriate context - 'spawn' for XPU, 'fork' for others + if is_xpu_available(): + ctx = mp.get_context("spawn") + else: + ctx = mp.get_context("fork") + queue = ctx.Queue() + queue.put(max_restarts) + notebook_launcher(tough_nut_function, (queue,), num_processes=NUM_PROCESSES, max_restarts=max_restarts) + + +@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test monitoring") +def test_monitoring(monitor_interval: float = 0.01, sleep_sec: int = 100): + start_time = time.time() + with raises(ChildFailedError, match="I'm an even process. I don't like to sleep."): + notebook_launcher( + bipolar_sleep_function, + (sleep_sec,), + num_processes=NUM_PROCESSES, + monitor_interval=monitor_interval, + ) + assert time.time() - start_time < sleep_sec, "Monitoring did not stop the process in time." + + +@require_bnb +def test_problematic_imports(): + with raises(RuntimeError, match="Please keep these imports"): + import bitsandbytes as bnb # noqa: F401 + + notebook_launcher(basic_function, (), num_processes=NUM_PROCESSES) + + +def main(): + print("Test basic notebook can be ran") + test_can_initialize() + print("Test static rendezvous backend") + test_static_rdzv_backend() + print("Test c10d rendezvous backend") + test_c10d_rdzv_backend() + print("Test fault tolerant") + test_fault_tolerant() + print("Test monitoring") + test_monitoring() + if is_bnb_available(): + print("Test problematic imports (bnb)") + test_problematic_imports() + if NUM_PROCESSES > 1: + PartialState().destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/accelerate/test_utils/scripts/test_ops.py b/accelerate/test_utils/scripts/test_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..1aa9b095e257bf63623b5127624b54c4a376c5ff --- /dev/null +++ b/accelerate/test_utils/scripts/test_ops.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python + +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from accelerate import PartialState +from accelerate.test_utils.testing import assert_exception +from accelerate.utils.dataclasses import DistributedType +from accelerate.utils.operations import ( + DistributedOperationException, + broadcast, + copy_tensor_to_devices, + gather, + gather_object, + pad_across_processes, + reduce, +) + + +def create_tensor(state): + return (torch.arange(state.num_processes) + 1.0 + (state.num_processes * state.process_index)).to(state.device) + + +def test_gather(state): + tensor = create_tensor(state) + gathered_tensor = gather(tensor) + assert gathered_tensor.tolist() == list(range(1, state.num_processes**2 + 1)) + + +def test_gather_object(state): + # Gather objects in TorchXLA is not supported. + if state.distributed_type == DistributedType.XLA: + return + obj = [state.process_index] + gathered_obj = gather_object(obj) + assert len(gathered_obj) == state.num_processes, f"{gathered_obj}, {len(gathered_obj)} != {state.num_processes}" + assert gathered_obj == list(range(state.num_processes)), f"{gathered_obj} != {list(range(state.num_processes))}" + + +def test_gather_non_contiguous(state): + # Skip this test because the 'is_contiguous' function of XLA tensor always returns True. + if state.distributed_type == DistributedType.XLA: + return + + # Create a non-contiguous tensor (enforce non-contiguity after device memory allocation) + tensor = torch.arange(12, device=state.device).view(4, 3).t() + assert not tensor.is_contiguous() + # Shouldn't error out + _ = gather(tensor) + + +def test_broadcast(state): + tensor = create_tensor(state) + broadcasted_tensor = broadcast(tensor) + assert broadcasted_tensor.shape == torch.Size([state.num_processes]) + assert broadcasted_tensor.tolist() == list(range(1, state.num_processes + 1)) + + +def test_pad_across_processes(state): + # We need to pad the tensor with one more element if we are the main process + # to ensure that we can pad + if state.is_main_process: + tensor = torch.arange(state.num_processes + 1).to(state.device) + else: + tensor = torch.arange(state.num_processes).to(state.device) + padded_tensor = pad_across_processes(tensor) + assert padded_tensor.shape == torch.Size([state.num_processes + 1]) + if not state.is_main_process: + assert padded_tensor.tolist() == list(range(0, state.num_processes)) + [0] + + +def test_reduce_sum(state): + # For now runs on only two processes + if state.num_processes != 2: + return + tensor = create_tensor(state) + reduced_tensor = reduce(tensor, "sum") + truth_tensor = torch.tensor([4.0, 6]).to(state.device) + assert torch.allclose(reduced_tensor, truth_tensor), f"{reduced_tensor} != {truth_tensor}" + + +def test_reduce_mean(state): + # For now runs on only two processes + if state.num_processes != 2: + return + tensor = create_tensor(state) + reduced_tensor = reduce(tensor, "mean") + truth_tensor = torch.tensor([2.0, 3]).to(state.device) + assert torch.allclose(reduced_tensor, truth_tensor), f"{reduced_tensor} != {truth_tensor}" + + +def test_op_checker(state): + # Must be in a distributed state, and gathering is currently not supported in TorchXLA. + if state.distributed_type in [DistributedType.NO, DistributedType.XLA]: + return + state.debug = True + # `pad_across_processes` + if state.process_index == 0: + data = {"tensor": torch.tensor([[0.0, 1, 2, 3, 4]]).to(state.device)} + else: + data = {"tensor": torch.tensor([[[0.0, 1, 2, 3, 4, 5]]]).to(state.device)} + + with assert_exception(DistributedOperationException): + pad_across_processes(data, dim=0) + + # `reduce` + if state.process_index == 0: + data = {"tensor": torch.tensor([[0.0, 1, 2, 3, 4]]).to(state.device)} + else: + data = {"tensor": torch.tensor([[[0.0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]]).to(state.device)} + + with assert_exception(DistributedOperationException): + reduce(data) + + # `broadcast` + if state.process_index == 0: + data = {"tensor": torch.tensor([[0.0, 1, 2, 3, 4]]).to(state.device)} + else: + data = {"tensor": torch.tensor([[[0.0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]]).to(state.device)} + + with assert_exception(DistributedOperationException): + broadcast(data) + + state.debug = False + + +def test_copy_tensor_to_devices(state): + if state.distributed_type not in [DistributedType.MULTI_GPU, DistributedType.XLA]: + return + if state.is_main_process: + tensor = torch.tensor([1, 2, 3], dtype=torch.int).to(state.device) + else: + tensor = None + tensor = copy_tensor_to_devices(tensor) + assert torch.allclose(tensor, torch.tensor([1, 2, 3], dtype=torch.int, device=state.device)) + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +def main(): + state = PartialState() + state.print(f"State: {state}") + state.print("testing gather") + test_gather(state) + state.print("testing gather_object") + test_gather_object(state) + state.print("testing gather non-contiguous") + test_gather_non_contiguous(state) + state.print("testing broadcast") + test_broadcast(state) + state.print("testing pad_across_processes") + test_pad_across_processes(state) + state.print("testing reduce_sum") + test_reduce_sum(state) + state.print("testing reduce_mean") + test_reduce_mean(state) + state.print("testing op_checker") + test_op_checker(state) + state.print("testing sending tensors across devices") + test_copy_tensor_to_devices(state) + state.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/accelerate/test_utils/scripts/test_script.py b/accelerate/test_utils/scripts/test_script.py new file mode 100644 index 0000000000000000000000000000000000000000..e53e12a18fa79c221640e2e73bdf83c552f66bfa --- /dev/null +++ b/accelerate/test_utils/scripts/test_script.py @@ -0,0 +1,909 @@ +#!/usr/bin/env python + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import io +import math +import time +from copy import deepcopy +from pathlib import Path + +import numpy as np +import torch +from torch.utils.data import DataLoader, Dataset + +from accelerate import Accelerator +from accelerate.data_loader import SeedableRandomSampler, prepare_data_loader +from accelerate.state import AcceleratorState +from accelerate.test_utils import RegressionDataset, RegressionModel, are_the_same_tensors +from accelerate.utils import ( + DataLoaderConfiguration, + DistributedType, + gather, + gather_object, + is_bf16_available, + is_cuda_available, + is_datasets_available, + is_fp16_available, + is_hpu_available, + is_mps_available, + is_pytest_available, + set_seed, + synchronize_rng_states, +) + + +if is_hpu_available(): + ATOL = 1e-3 + RTOL = 1e-3 +else: + ATOL = 1e-6 + RTOL = 1e-6 + + +def generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler=False): + "Creates a dataloader that can also use the `SeedableRandomSampler`" + if use_seedable_sampler: + # The SeedableRandomSampler is needed during distributed setups + # for full reproducibility across processes with the `DataLoader` + sampler = SeedableRandomSampler( + generator=generator, + data_source=train_set, + num_samples=len(train_set), + ) + return DataLoader(train_set, batch_size=batch_size, sampler=sampler) + else: + return DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator) + + +def print_main(state): + print(f"Printing from the main process {state.process_index}") + + +def print_local_main(state): + print(f"Printing from the local main process {state.local_process_index}") + + +def print_last(state): + print(f"Printing from the last process {state.process_index}") + + +def print_on(state, process_idx): + print(f"Printing from process {process_idx}: {state.process_index}") + + +def process_execution_check(): + accelerator = Accelerator() + num_processes = accelerator.num_processes + # Test main_process_first context manager + path = Path("check_main_process_first.txt") + with accelerator.main_process_first(): + if accelerator.is_main_process: + time.sleep(0.1) # ensure main process takes longest + with open(path, "a+") as f: + f.write("Currently in the main process\n") + else: + with open(path, "a+") as f: + f.write("Now on another process\n") + accelerator.wait_for_everyone() + + if accelerator.is_main_process: + with open(path) as f: + text = "".join(f.readlines()) + try: + assert text.startswith("Currently in the main process\n"), "Main process was not first" + if num_processes > 1: + assert text.endswith("Now on another process\n"), "Main process was not first" + assert text.count("Now on another process\n") == accelerator.num_processes - 1, ( + f"Only wrote to file {text.count('Now on another process') + 1} times, not {accelerator.num_processes}" + ) + except AssertionError: + path.unlink() + raise + + if accelerator.is_main_process and path.exists(): + path.unlink() + accelerator.wait_for_everyone() + # Test the decorators + f = io.StringIO() + with contextlib.redirect_stdout(f): + accelerator.on_main_process(print_main)(accelerator.state) + result = f.getvalue().rstrip() + if accelerator.is_main_process: + assert result == "Printing from the main process 0", f"{result} != Printing from the main process 0" + else: + assert f.getvalue().rstrip() == "", f'{result} != ""' + f.truncate(0) + f.seek(0) + + with contextlib.redirect_stdout(f): + accelerator.on_local_main_process(print_local_main)(accelerator.state) + if accelerator.is_local_main_process: + assert f.getvalue().rstrip() == "Printing from the local main process 0" + else: + assert f.getvalue().rstrip() == "" + f.truncate(0) + f.seek(0) + + with contextlib.redirect_stdout(f): + accelerator.on_last_process(print_last)(accelerator.state) + if accelerator.is_last_process: + assert f.getvalue().rstrip() == f"Printing from the last process {accelerator.state.num_processes - 1}" + else: + assert f.getvalue().rstrip() == "" + f.truncate(0) + f.seek(0) + + for process_idx in range(num_processes): + with contextlib.redirect_stdout(f): + accelerator.on_process(print_on, process_index=process_idx)(accelerator.state, process_idx) + if accelerator.process_index == process_idx: + assert f.getvalue().rstrip() == f"Printing from process {process_idx}: {accelerator.process_index}" + else: + assert f.getvalue().rstrip() == "" + f.truncate(0) + f.seek(0) + + +def init_state_check(): + # Test we can instantiate this twice in a row. + state = AcceleratorState() + if state.local_process_index == 0: + print("Testing, testing. 1, 2, 3.") + print(state) + + +def rng_sync_check(): + state = AcceleratorState() + synchronize_rng_states(["torch"]) + assert are_the_same_tensors(torch.get_rng_state()), "RNG states improperly synchronized on CPU." + if state.distributed_type == DistributedType.MULTI_GPU: + synchronize_rng_states(["cuda"]) + assert are_the_same_tensors(torch.cuda.get_rng_state()), "RNG states improperly synchronized on GPU." + elif state.distributed_type == DistributedType.MULTI_XPU: + synchronize_rng_states(["xpu"]) + assert are_the_same_tensors(torch.xpu.get_rng_state()), "RNG states improperly synchronized on XPU." + generator = torch.Generator() + synchronize_rng_states(["generator"], generator=generator) + assert are_the_same_tensors(generator.get_state()), "RNG states improperly synchronized in generator." + + if state.local_process_index == 0: + print("All rng are properly synched.") + + +def dl_preparation_check(): + state = AcceleratorState() + length = 32 * state.num_processes + + dl = DataLoader(range(length), batch_size=8) + dl = prepare_data_loader(dl, state.device, state.num_processes, state.process_index, put_on_device=True) + result = [] + for batch in dl: + result.append(gather(batch)) + result = torch.cat(result) + + assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result." + + dl = DataLoader(range(length), batch_size=8) + dl = prepare_data_loader( + dl, + state.device, + state.num_processes, + state.process_index, + put_on_device=True, + split_batches=True, + ) + result = [] + for batch in dl: + result.append(gather(batch)) + result = torch.cat(result) + assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result." + + if state.process_index == 0: + print("Non-shuffled dataloader passing.") + + dl = DataLoader(range(length), batch_size=8, shuffle=True) + dl = prepare_data_loader(dl, state.device, state.num_processes, state.process_index, put_on_device=True) + result = [] + for batch in dl: + result.append(gather(batch)) + result = torch.cat(result).tolist() + result.sort() + assert result == list(range(length)), "Wrong shuffled dataloader result." + + dl = DataLoader(range(length), batch_size=8, shuffle=True) + dl = prepare_data_loader( + dl, + state.device, + state.num_processes, + state.process_index, + put_on_device=True, + split_batches=True, + ) + result = [] + for batch in dl: + result.append(gather(batch)) + result = torch.cat(result).tolist() + result.sort() + assert result == list(range(length)), "Wrong shuffled dataloader result." + + if state.local_process_index == 0: + print("Shuffled dataloader passing.") + + +def central_dl_preparation_check(): + state = AcceleratorState() + length = 32 * state.num_processes + + dl = DataLoader(range(length), batch_size=8) + dl = prepare_data_loader( + dl, state.device, state.num_processes, state.process_index, put_on_device=True, dispatch_batches=True + ) + result = [] + for batch in dl: + result.append(gather(batch)) + result = torch.cat(result) + assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result." + + dl = DataLoader(range(length), batch_size=8) + dl = prepare_data_loader( + dl, + state.device, + state.num_processes, + state.process_index, + put_on_device=True, + split_batches=True, + dispatch_batches=True, + ) + result = [] + for batch in dl: + result.append(gather(batch)) + result = torch.cat(result) + assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result." + + if state.process_index == 0: + print("Non-shuffled central dataloader passing.") + + dl = DataLoader(range(length), batch_size=8, shuffle=True) + dl = prepare_data_loader( + dl, state.device, state.num_processes, state.process_index, put_on_device=True, dispatch_batches=True + ) + result = [] + for batch in dl: + result.append(gather(batch)) + result = torch.cat(result).tolist() + result.sort() + assert result == list(range(length)), "Wrong shuffled dataloader result." + + dl = DataLoader(range(length), batch_size=8, shuffle=True) + dl = prepare_data_loader( + dl, + state.device, + state.num_processes, + state.process_index, + put_on_device=True, + split_batches=True, + dispatch_batches=True, + ) + result = [] + for batch in dl: + result.append(gather(batch)) + result = torch.cat(result).tolist() + result.sort() + assert result == list(range(length)), "Wrong shuffled dataloader result." + + if state.local_process_index == 0: + print("Shuffled central dataloader passing.") + + +def custom_sampler_check(): + state = AcceleratorState() + + class CustomDataset(Dataset): + def __init__(self, data): + self.data = data + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + return self.data[index] + + class CustomBatchSampler: + def __init__(self, dataset_length: int, batch_size: int, shuffle: bool = True): + self.batch_size = batch_size + self.data_index = np.arange(dataset_length) + self.shuffle = shuffle + + def __iter__(self): + num_batches = len(self) + if self.shuffle: + index = np.random.permutation(self.data_index) + else: + index = self.data_index + output = np.array_split(index, num_batches) + yield from output + + def __len__(self): + return math.ceil(len(self.data_index) / self.batch_size) + + dataset = CustomDataset(range(32 * state.num_processes)) + sampler = CustomBatchSampler(len(dataset), batch_size=8) + dl = DataLoader(dataset, batch_sampler=sampler) + dl = prepare_data_loader(dl, state.device, state.num_processes, state.process_index) + # We need just ensure that `dl.batch_sampler` (or `dl.batch_sampler.batch_sampler` is indeed the old batch sampler + if hasattr(dl.batch_sampler, "batch_sampler"): + assert isinstance(dl.batch_sampler.batch_sampler, CustomBatchSampler), ( + "Custom sampler was changed after calling `prepare_data_loader`" + ) + else: + assert isinstance(dl.batch_sampler, CustomBatchSampler), ( + "Custom sampler was changed after calling `prepare_data_loader`" + ) + + +def check_seedable_sampler(): + # Set seed + set_seed(42) + train_set = RegressionDataset(length=10, seed=42) + train_dl = DataLoader(train_set, batch_size=2, shuffle=True) + + config = DataLoaderConfiguration(use_seedable_sampler=True) + accelerator = Accelerator(dataloader_config=config) + train_dl = accelerator.prepare(train_dl) + original_items = [] + for _ in range(3): + for batch in train_dl: + original_items.append(batch["x"]) + original_items = torch.cat(original_items) + + # Set seed again and the epoch + set_seed(42) + train_dl.set_epoch(0) + new_items = [] + for _ in range(3): + for batch in train_dl: + new_items.append(batch["x"]) + new_items = torch.cat(new_items) + assert torch.allclose(original_items, new_items), "Did not obtain the same items with the same seed and epoch." + + +def check_seedable_sampler_in_batch_sampler_shard(): + set_seed(42) + + config = DataLoaderConfiguration(use_seedable_sampler=True) + accelerator = Accelerator(dataloader_config=config) + assert accelerator.num_processes > 1, "This test requires more than one process." + + dataloader = DataLoader(list(range(10)), batch_size=1, shuffle=True) + prepared_data_loader = prepare_data_loader( + dataloader=dataloader, + use_seedable_sampler=True, + ) + + target_sampler = prepared_data_loader.batch_sampler.batch_sampler.sampler + assert isinstance(target_sampler, SeedableRandomSampler), ( + "Sampler in BatchSamplerShard is not SeedableRandomSampler." + ) + + +def check_seedable_sampler_with_data_seed(): + # Set seed + set_seed(42) + data_seed = 42 + train_set = RegressionDataset(length=10, seed=42) + train_dl = DataLoader(train_set, batch_size=2, shuffle=True) + + config = DataLoaderConfiguration(use_seedable_sampler=True, data_seed=data_seed) + accelerator = Accelerator(dataloader_config=config) + prepared_dl = accelerator.prepare(train_dl) + original_items = [] + for _ in range(3): + for batch in prepared_dl: + original_items.append(batch["x"]) + original_items = torch.cat(original_items) + + # Set new data seed + config.data_seed = 43 + accelerator = Accelerator(dataloader_config=config) + prepared_dl = accelerator.prepare(train_dl) + new_items = [] + for _ in range(3): + for batch in prepared_dl: + new_items.append(batch["x"]) + new_items = torch.cat(new_items) + assert not torch.allclose(original_items, new_items), "Obtained the same items with different data seed." + + +def mock_training(length, batch_size, generator, use_seedable_sampler=False): + set_seed(42) + generator.manual_seed(42) + train_set = RegressionDataset(length=length, seed=42) + + train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler) + model = RegressionModel() + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + for epoch in range(3): + for batch in train_dl: + model.zero_grad() + output = model(batch["x"]) + loss = torch.nn.functional.mse_loss(output, batch["y"]) + loss.backward() + optimizer.step() + return train_set, model + + +def training_check(use_seedable_sampler=False): + state = AcceleratorState() + generator = torch.Generator() + batch_size = 8 + length = batch_size * 4 * state.num_processes + + train_set, old_model = mock_training(length, batch_size * state.num_processes, generator, use_seedable_sampler) + assert are_the_same_tensors(old_model.a), "Did not obtain the same model on both processes." + assert are_the_same_tensors(old_model.b), "Did not obtain the same model on both processes." + + accelerator = Accelerator() + train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler) + model = RegressionModel() + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + + train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer) + set_seed(42) + generator.manual_seed(42) + for _ in range(3): + for batch in train_dl: + model.zero_grad() + output = model(batch["x"]) + loss = torch.nn.functional.mse_loss(output, batch["y"]) + accelerator.backward(loss) + optimizer.step() + + model = accelerator.unwrap_model(model).cpu() + torch.testing.assert_close( + old_model.a, + model.a, + atol=ATOL, + rtol=RTOL, + msg=lambda msg: f"Did not obtain the same model on CPU or distributed training.\n{msg}", + ) + torch.testing.assert_close( + old_model.b, + model.b, + atol=ATOL, + rtol=RTOL, + msg=lambda msg: f"Did not obtain the same model on CPU or distributed training.\n{msg}", + ) + + accelerator.print("Training yielded the same results on one CPU or distributed setup with no batch split.") + + dataloader_config = DataLoaderConfiguration(split_batches=True, use_seedable_sampler=use_seedable_sampler) + accelerator = Accelerator(dataloader_config=dataloader_config) + train_dl = generate_baseline_dataloader( + train_set, generator, batch_size * state.num_processes, use_seedable_sampler + ) + model = RegressionModel() + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + + train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer) + set_seed(42) + generator.manual_seed(42) + for _ in range(3): + for batch in train_dl: + model.zero_grad() + output = model(batch["x"]) + loss = torch.nn.functional.mse_loss(output, batch["y"]) + accelerator.backward(loss) + optimizer.step() + + model = accelerator.unwrap_model(model).cpu() + torch.testing.assert_close( + old_model.a, + model.a, + atol=ATOL, + rtol=RTOL, + msg=lambda msg: f"Did not obtain the same model on CPU or distributed training.\n{msg}", + ) + torch.testing.assert_close( + old_model.b, + model.b, + atol=ATOL, + rtol=RTOL, + msg=lambda msg: f"Did not obtain the same model on CPU or distributed training.\n{msg}", + ) + + accelerator.print("Training yielded the same results on one CPU or distributed setup with batch split.") + + # FP32 wrapper check + if is_cuda_available() or is_mps_available(): + # Mostly a test that model.forward will have autocast when running unwrap_model(model, keep_fp32_wrapper=True) + print("Keep fp32 wrapper check.") + AcceleratorState._reset_state() + accelerator = Accelerator(mixed_precision="fp16") + + model = torch.nn.Linear(2, 4) + model = accelerator.prepare(model) + model_with_fp32_wrapper = accelerator.unwrap_model(model, keep_fp32_wrapper=True) + + # Run forward with fp16 as input. + # When the model is with mixed precision wrapper, no error will be raised. + input_tensor = torch.Tensor([1, 2]).to(dtype=torch.float16, device=accelerator.device) + output = model_with_fp32_wrapper(input_tensor) + + # BF16 support + if is_bf16_available(): + # Mostly a test that BF16 doesn't crash as the operation inside the model is not converted to BF16 + print("BF16 training check.") + AcceleratorState._reset_state() + dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler) + accelerator = Accelerator(mixed_precision="bf16", dataloader_config=dataloader_config) + train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler) + model = RegressionModel() + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + + train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer) + set_seed(42) + generator.manual_seed(42) + for _ in range(3): + for batch in train_dl: + model.zero_grad() + output = model(batch["x"]) + loss = torch.nn.functional.mse_loss(output, batch["y"]) + accelerator.backward(loss) + optimizer.step() + + model = accelerator.unwrap_model(model).cpu() + torch.testing.assert_close( + old_model.a, + model.a, + atol=ATOL, + rtol=RTOL, + msg=lambda msg: f"Did not obtain the same model on CPU or distributed training.\n{msg}", + ) + torch.testing.assert_close( + old_model.b, + model.b, + atol=ATOL, + rtol=RTOL, + msg=lambda msg: f"Did not obtain the same model on CPU or distributed training.\n{msg}", + ) + + # FP16 support (HPU fp16 model seems to be off by 10% from the CPU, which is a lot of numerical error) + if is_fp16_available() and not is_hpu_available(): + # Mostly a test that FP16 doesn't crash as the operation inside the model is not converted to FP16 + print("FP16 training check.") + AcceleratorState._reset_state() + dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler) + accelerator = Accelerator(mixed_precision="fp16", dataloader_config=dataloader_config) + train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler) + model = RegressionModel() + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + + train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer) + set_seed(42) + generator.manual_seed(42) + for _ in range(3): + for batch in train_dl: + model.zero_grad() + output = model(batch["x"]) + loss = torch.nn.functional.mse_loss(output, batch["y"]) + accelerator.backward(loss) + optimizer.step() + + model = accelerator.unwrap_model(model).cpu() + torch.testing.assert_close( + old_model.a, + model.a, + atol=ATOL, + rtol=RTOL, + msg=lambda msg: f"Did not obtain the same model on CPU or distributed training.\n{msg}", + ) + torch.testing.assert_close( + old_model.b, + model.b, + atol=ATOL, + rtol=RTOL, + msg=lambda msg: f"Did not obtain the same model on CPU or distributed training.\n{msg}", + ) + + +def test_split_between_processes_dataset(datasets_Dataset): + state = AcceleratorState() + data = datasets_Dataset.from_list([dict(k=v) for v in range(2 * state.num_processes)]) + with state.split_between_processes(data, apply_padding=False) as results: + assert len(results) == 2, ( + f"Each process did not have two items. Process index: {state.process_index}; Length: {len(results)}" + ) + + data = datasets_Dataset.from_list([dict(k=v) for v in range(2 * state.num_processes - 1)]) + with state.split_between_processes(data, apply_padding=False) as results: + if state.is_last_process: + assert len(results) == 1, ( + f"Last process did not receive a single item. Process index: {state.process_index}; Length: {len(results)}" + ) + else: + assert len(results) == 2, ( + f"One of the intermediate processes did not receive two items. Process index: {state.process_index}; Length: {len(results)}" + ) + state.wait_for_everyone() + + odd_data = datasets_Dataset.from_list([dict(k=v) for v in range(2 * state.num_processes - 1)]) + even_data = datasets_Dataset.from_list([dict(k=v) for v in range(2 * state.num_processes)]) + + for data in [odd_data, even_data]: + expected_output = data["k"] + + with state.split_between_processes(data, apply_padding=True) as results: + if state.num_processes == 1: + assert len(results) == len(data), ( + f"Single process did not receive all items. Process index: {state.process_index}; Length: {len(results)}" + ) + else: + assert len(results) == 2, ( + f"Each process did not have two items. Process index: {state.process_index}; Length: {len(results)}" + ) + + results_per_process = [] + for result in results: + results_per_process.append(result) + + state.wait_for_everyone() + + gathered_results = gather_object(results_per_process) + output = [r["k"] for r in gathered_results[: len(data)]] + + assert expected_output == output, f"Gathered results is incorrect. Expected: {expected_output}; Got: {output}" + + +def test_split_between_processes_list(): + state = AcceleratorState() + data = list(range(0, 2 * state.num_processes)) + with state.split_between_processes(data) as results: + assert len(results) == 2, ( + f"Each process did not have two items. Process index: {state.process_index}; Length: {len(results)}" + ) + state.wait_for_everyone() + + even_data = list(range(0, (2 * state.num_processes))) + odd_data = list(range(0, (2 * state.num_processes) - 1)) + for data in [odd_data, even_data]: + expected_output = data + + with state.split_between_processes(data, apply_padding=True) as results: + num_samples_per_device = math.ceil(len(data) / state.num_processes) + # Test all processes gets the correct number of item(s) + assert len(results) == num_samples_per_device, ( + f"Process {state.device} did not get the correct number of item(s). Process index: {state.process_index}; Length: {len(results)}" + ) + + results_per_process = [] + for result in results: + results_per_process.append(result) + + state.wait_for_everyone() + + gathered_results = gather_object(results_per_process) + output = gathered_results[: len(data)] + + assert expected_output == output, f"Gathered results is incorrect. Expected: {expected_output}; Got: {output}" + + +def test_split_between_processes_nested_dict(): + state = AcceleratorState() + a = [1, 2, 3, 4, 5, 6, 7, 8] + b = ["a", "b", "c", "d", "e", "f", "g", "h"] + c = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) + if state.num_processes in (1, 2, 4): + data = {"a": a, "b": b, "c": c} + data_copy = deepcopy(data) + with state.split_between_processes(data) as results: + if state.process_index == 0: + assert results["a"] == data_copy["a"][: 8 // state.num_processes] + elif state.num_processes == 2: + assert results["a"] == data_copy["a"][4:] + elif state.process_index == 3: + # We return a list each time + assert results["a"] == data_copy["a"][-2:], f"Expected: {data_copy['a'][-2]}, Actual: {results['a']}" + if state.process_index == 0: + assert results["b"] == data_copy["b"][: 8 // state.num_processes] + elif state.num_processes == 2: + assert results["b"] == data_copy["b"][4:] + elif state.process_index == 3: + assert results["b"] == data_copy["b"][-2:] + if state.process_index == 0: + assert torch.allclose(results["c"], data_copy["c"][: 8 // state.num_processes]), ( + f"Did not obtain expected values on process 0, expected `{data['c'][: 8 // state.num_processes]}`, received: {results['c']}" + ) + elif state.num_processes == 2: + assert torch.allclose(results["c"], data_copy["c"][4:]), ( + f"Did not obtain expected values on process 2, expected `{data['c'][4:]}`, received: {results['c']}" + ) + elif state.process_index == 3: + assert torch.allclose(results["c"], data_copy["c"][-2:]), ( + f"Did not obtain expected values on process 4, expected `{data['c'][-2:]}`, received: {results['c']}" + ) + + state.wait_for_everyone() + + +def test_split_between_processes_tensor(): + state = AcceleratorState() + if state.num_processes > 1: + data = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]]).to(state.device) + with state.split_between_processes(data) as results: + if state.process_index == 0: + expected = torch.tensor([[0, 1, 2, 3]]).to(state.device) + else: + expected = torch.tensor([[4, 5, 6, 7]]).to(state.device) + torch.testing.assert_close(results, expected) + state.wait_for_everyone() + + even_data = torch.tensor([[i] for i in range(2 * state.num_processes)]).to(state.device) + odd_data = torch.tensor([[i] for i in range(2 * state.num_processes - 1)]).to(state.device) + for data in [even_data, odd_data]: + expected_output = [torch.tensor(i) for i in data.tolist()] + + with state.split_between_processes(data, apply_padding=True) as results: + num_samples_per_device = math.ceil(len(data) / state.num_processes) + assert len(results) == num_samples_per_device, ( + f"Process {state.device} did not get the correct number of item(s). Process index: {state.process_index}; Length: {len(results)}" + ) + results_per_process = [] + for result in results: + results_per_process.append(result.to("cpu")) + + state.wait_for_everyone() + + gathered_results = gather_object(results_per_process) + output = gathered_results[: len(data)] + + assert expected_output == output, f"Gathered results is incorrect. Expected: {expected_output}; Got: {output}" + + +def test_split_between_processes_evenly(): + state = AcceleratorState() + if state.num_processes in (1, 2, 4, 8): + data = list(range(17)) + num_samples_per_process = len(data) // state.num_processes + num_extras = len(data) % state.num_processes + with state.split_between_processes(data) as results: + if state.process_index < num_extras: + assert len(results) == num_samples_per_process + 1, ( + f"Each Process should have even elements. Expected: {num_samples_per_process + 1}, Actual: {len(results)}" + ) + else: + assert len(results) == num_samples_per_process, ( + f"Each Process should have even elements. Expected: {num_samples_per_process}, Actual: {len(results)}" + ) + state.wait_for_everyone() + + +def test_trigger(): + accelerator = Accelerator() + # should start with being false + assert accelerator.check_trigger() is False + + # set a breakpoint on the main process + if accelerator.is_main_process: + accelerator.set_trigger() + + # check it's been activated across all processes + # calls `all_reduce` and triggers a sync + assert accelerator.check_trigger() is True + + # check it's been reset after the sync + assert accelerator.check_trigger() is False + + +def test_reinstantiated_state(): + import pytest + + AcceleratorState._reset_state() + simple_model = torch.nn.Linear(1, 1) + # First define an accelerator + accelerator = Accelerator() + # Then call `reset_state`, breaking the state existing in the accelerator + AcceleratorState._reset_state() + # Now try and prepare a simple model, should raise the custom error early + with pytest.raises(AttributeError) as cm: + accelerator.prepare(simple_model) + assert "`AcceleratorState` object has no attribute" in str(cm.value.args[0]) + assert "This happens if `AcceleratorState._reset_state()`" in str(cm.value.args[0]) + + +def main(): + accelerator = Accelerator() + state = accelerator.state + if state.local_process_index == 0: + print("**Initialization**") + init_state_check() + state.wait_for_everyone() + + if state.distributed_type == DistributedType.MULTI_GPU: + num_processes_per_node = torch.cuda.device_count() + else: + num_processes_per_node = state.num_processes + + # We only run this test on non-multinode + if num_processes_per_node == state.num_processes: + if state.process_index == 0: + print("\n**Test process execution**") + process_execution_check() + + if state.process_index == 0: + print("\n**Test split between processes as a list**") + test_split_between_processes_list() + + if state.process_index == 0: + print("\n**Test split between processes as a dict**") + test_split_between_processes_nested_dict() + + if state.process_index == 0: + print("\n**Test split between processes as a tensor**") + test_split_between_processes_tensor() + + if state.process_index == 0: + print("\n**Test split between processes evenly**") + test_split_between_processes_evenly() + + if state.process_index == 0: + print("\n**Test split between processes as a datasets.Dataset**") + if is_datasets_available(): + from datasets import Dataset as datasets_Dataset + + test_split_between_processes_dataset(datasets_Dataset) + else: + print("Skipped because Hugging Face datasets is not available") + + if state.local_process_index == 0: + print("\n**Test random number generator synchronization**") + rng_sync_check() + + if state.local_process_index == 0: + print("\n**DataLoader integration test**") + dl_preparation_check() + if state.distributed_type != DistributedType.XLA: + central_dl_preparation_check() + custom_sampler_check() + check_seedable_sampler() + check_seedable_sampler_with_data_seed() + + if state.num_processes > 1: + check_seedable_sampler_in_batch_sampler_shard() + + # Trainings are not exactly the same in DeepSpeed and CPU mode + if state.distributed_type == DistributedType.DEEPSPEED: + return + + if state.local_process_index == 0: + print("\n**Training integration test**") + training_check(use_seedable_sampler=False) + training_check(use_seedable_sampler=True) + + if state.local_process_index == 0: + print("\n**Breakpoint trigger test**") + test_trigger() + + if is_pytest_available(): + if state.local_process_index == 0: + print("\n**Test reinstantiated state**") + test_reinstantiated_state() + + state.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/accelerate/test_utils/scripts/test_sync.py b/accelerate/test_utils/scripts/test_sync.py new file mode 100644 index 0000000000000000000000000000000000000000..310e46e1920466fdf57866919472f619a41f5b33 --- /dev/null +++ b/accelerate/test_utils/scripts/test_sync.py @@ -0,0 +1,413 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy + +import torch +import torch.nn.functional as F +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader + +from accelerate.accelerator import Accelerator, DataLoaderConfiguration, GradientAccumulationPlugin +from accelerate.state import GradientState +from accelerate.test_utils import RegressionDataset, RegressionModel +from accelerate.utils import DistributedType, set_seed + + +def check_model_parameters(model_a, model_b, did_step, iteration, **kwargs): + for param, grad_param in zip(model_a.parameters(), model_b.parameters()): + if not param.requires_grad: + continue + if not did_step: + # Grads should not be in sync + assert torch.allclose(param.grad, grad_param.grad, **kwargs) is False, ( + f"Gradients in sync when they should not be at iteration {iteration}:\nmodel_a grad ({param.grad}) == model_b grad ({grad_param.grad})" + ) + else: + # Grads should be in sync + assert torch.allclose(param.grad, grad_param.grad, **kwargs) is True, ( + f"Gradients not in sync when they should be at iteration {iteration}:\nmodel_a grad ({param.grad}) != model_b grad ({grad_param.grad})" + ) + + +def step_model(model, input, target, accelerator, do_backward=True): + model.train() + output = model(input) + loss = F.mse_loss(output, target.to(output.device)) + if not do_backward: + loss /= accelerator.gradient_accumulation_steps + loss.backward() + else: + accelerator.backward(loss) + + +def get_training_setup(accelerator, sched=False): + "Returns everything needed to perform basic training" + set_seed(42) + model = RegressionModel() + ddp_model = deepcopy(model) + dset = RegressionDataset(length=80) + dataloader = DataLoader(dset, batch_size=16) + model.to(accelerator.device) + if sched: + opt = AdamW(params=model.parameters(), lr=1e-3) + ddp_opt = AdamW(params=ddp_model.parameters(), lr=1e-3) + sched = LambdaLR(opt, lr_lambda=lambda epoch: epoch**0.65) + ddp_sched = LambdaLR(ddp_opt, lr_lambda=lambda epoch: epoch**0.65) + # Make a copy of `model` + if sched: + ddp_model, ddp_opt, ddp_sched, dataloader = accelerator.prepare(ddp_model, ddp_opt, ddp_sched, dataloader) + else: + ddp_model, dataloader = accelerator.prepare(ddp_model, dataloader) + if sched: + return (model, opt, sched, dataloader, ddp_model, ddp_opt, ddp_sched) + return model, ddp_model, dataloader + + +def test_noop_sync(accelerator): + # Test when on a single CPU or GPU that the context manager does nothing + model, ddp_model, dataloader = get_training_setup(accelerator) + # Use a single batch + ddp_input, ddp_target = next(iter(dataloader)).values() + for iteration in range(3): + # Gather the distributed inputs and targs for the base model + input, target = accelerator.gather((ddp_input, ddp_target)) + input, target = input.to(accelerator.device), target.to(accelerator.device) + # Perform our initial ground truth step in non "DDP" + step_model(model, input, target, accelerator) + # Do "gradient accumulation" (noop) + if iteration % 2 == 0: + # Accumulate grads locally + with accelerator.no_sync(ddp_model): + step_model(ddp_model, ddp_input, ddp_target, accelerator) + else: + # Sync grads + step_model(ddp_model, ddp_input, ddp_target, accelerator) + + # Since `no_sync` is a noop, `ddp_model` and `model` grads should always be in sync + check_model_parameters(model, ddp_model, True, iteration) + for param, ddp_param in zip(model.parameters(), ddp_model.parameters()): + if not param.requires_grad: + continue + assert torch.allclose(param.grad, ddp_param.grad), ( + f"Gradients not in sync when they should be:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})" + ) + + # Shuffle ddp_input on each iteration + torch.manual_seed(1337 + iteration) + ddp_input = ddp_input[torch.randperm(len(ddp_input))] + + +def test_distributed_sync(accelerator): + # Test on distributed setup that context manager behaves properly + model, ddp_model, dataloader = get_training_setup(accelerator) + # Use a single batch + ddp_input, ddp_target = next(iter(dataloader)).values() + for iteration in range(3): + # Gather the distributed inputs and targs for the base model + input, target = accelerator.gather((ddp_input, ddp_target)) + input, target = input.to(accelerator.device), target.to(accelerator.device) + # Perform our initial ground truth step in non "DDP" + step_model(model, input, target, accelerator) + # Do "gradient accumulation" (noop) + if iteration % 2 == 0: + # Accumulate grads locally + with accelerator.no_sync(ddp_model): + step_model(ddp_model, ddp_input, ddp_target, accelerator) + else: + # Sync grads + step_model(ddp_model, ddp_input, ddp_target, accelerator) + + # DDP model and model should only be in sync when not (iteration % 2 == 0) + for param, ddp_param in zip(model.parameters(), ddp_model.parameters()): + if not param.requires_grad: + continue + if iteration % 2 == 0: + # Grads should not be in sync + assert torch.allclose(param.grad, ddp_param.grad) is False, ( + f"Gradients in sync when they should not be:\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})" + ) + else: + # Grads should be in sync + assert torch.allclose(param.grad, ddp_param.grad) is True, ( + f"Gradients not in sync when they should be:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})" + ) + + # Shuffle ddp_input on each iteration + torch.manual_seed(1337 + iteration) + ddp_input = ddp_input[torch.randperm(len(ddp_input))] + + +def test_distributed_sync_multiple_fwd(accelerator): + # Test on distributed setup that context manager behaves properly when used with multiple forwards followed by multiple backwards + model, ddp_model, dataloader = get_training_setup(accelerator) + # Do multiple forwards + losses = [] + num_iterations = 3 + for iteration in range(num_iterations): + ddp_input, ddp_target = next(iter(dataloader)).values() + + # Gather the distributed inputs and targs for the base model + input, target = accelerator.gather((ddp_input, ddp_target)) + input, target = input.to(accelerator.device), target.to(accelerator.device) + + # Perform our initial ground truth step in non "DDP" + step_model(model, input, target, accelerator) + + # Accumulate grads locally + with accelerator.no_sync(ddp_model): + ddp_output = ddp_model(ddp_input) + loss = F.mse_loss(ddp_output, ddp_target.to(ddp_output.device)) + losses.append(loss) + + # Do multiple backwards and sync only at the last backward + for iteration in range(num_iterations): + loss = losses[iteration] + + if iteration < num_iterations - 1: + # Accumulate grads locally + accelerator.backward(loss) + + # DDP model and model should only be in sync after last backward + for param, ddp_param in zip(model.parameters(), ddp_model.parameters()): + if not param.requires_grad: + continue + # Grads should not be in sync + assert torch.allclose(param.grad, ddp_param.grad) is False, ( + f"Gradients in sync when they should not be:\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})" + ) + + else: + # Sync grads if last backward + with accelerator.trigger_sync_in_backward(ddp_model): + accelerator.backward(loss) + + # DDP model and model should only be in sync after last backward + for param, ddp_param in zip(model.parameters(), ddp_model.parameters()): + if not param.requires_grad: + continue + # Grads should be in sync + assert torch.allclose(param.grad, ddp_param.grad) is True, ( + f"Gradients not in sync when they should be:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})" + ) + + +def test_gradient_accumulation(split_batches=False, dispatch_batches=False, sync_each_batch=False): + gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=2, sync_each_batch=sync_each_batch) + dataloader_config = DataLoaderConfiguration(split_batches=split_batches, dispatch_batches=dispatch_batches) + accelerator = Accelerator( + dataloader_config=dataloader_config, + gradient_accumulation_plugin=gradient_accumulation_plugin, + ) + # Test that context manager behaves properly + model, ddp_model, dataloader = get_training_setup(accelerator) + for iteration, batch in enumerate(dataloader): + ddp_input, ddp_target = batch.values() + # Gather the distributed inputs and targs for the base model + input, target = accelerator.gather((ddp_input, ddp_target)) + input, target = input.to(accelerator.device), target.to(accelerator.device) + # Perform our initial ground truth step in non "DDP" + step_model(model, input, target, accelerator, False) + # Do "gradient accumulation" (noop) + with accelerator.accumulate(ddp_model): + step_model(ddp_model, ddp_input, ddp_target, accelerator) + + # DDP model and model should only be in sync when not (iteration % 2 == 0) + for param, ddp_param in zip(model.parameters(), ddp_model.parameters()): + if not param.requires_grad: + continue + if ((iteration + 1) % 2 == 0) or (iteration == len(dataloader) - 1) or sync_each_batch: + # Grads should be in sync + assert torch.allclose(param.grad, ddp_param.grad) is True, ( + f"Gradients not in sync when they should be at iteration {iteration}:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})" + ) + else: + # Grads should not be in sync + assert torch.allclose(param.grad, ddp_param.grad) is False, ( + f"Gradients in sync when they should not be at iteration {iteration}:\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})" + ) + + # Shuffle ddp_input on each iteration + torch.manual_seed(1337 + iteration) + ddp_input = ddp_input[torch.randperm(len(ddp_input))] + GradientState._reset_state() + + +def test_gradient_accumulation_with_opt_and_scheduler( + split_batches=False, dispatch_batches=False, sync_each_batch=False +): + gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=2, sync_each_batch=sync_each_batch) + dataloader_config = DataLoaderConfiguration(split_batches=split_batches, dispatch_batches=dispatch_batches) + accelerator = Accelerator( + dataloader_config=dataloader_config, + gradient_accumulation_plugin=gradient_accumulation_plugin, + ) + # Test that context manager behaves properly + model, opt, sched, dataloader, ddp_model, ddp_opt, ddp_sched = get_training_setup(accelerator, True) + for iteration, batch in enumerate(dataloader): + ddp_input, ddp_target = batch.values() + # Gather the distributed inputs and targs for the base model + input, target = accelerator.gather((ddp_input, ddp_target)) + input, target = input.to(accelerator.device), target.to(accelerator.device) + # Perform our initial ground truth step in non "DDP" + model.train() + ddp_model.train() + step_model(model, input, target, accelerator, False) + opt.step() + + if ((iteration + 1) % 2 == 0) or ((iteration + 1) == len(dataloader)): + if split_batches: + sched.step() + else: + for _ in range(accelerator.num_processes): + sched.step() + + # Perform gradient accumulation under wrapper + with accelerator.accumulate(ddp_model): + step_model(ddp_model, ddp_input, ddp_target, accelerator) + ddp_opt.step() + ddp_sched.step() + + # Learning rates should be the same + assert opt.param_groups[0]["lr"] == ddp_opt.param_groups[0]["lr"], ( + f"Learning rates found in each optimizer did not align\nopt: {opt.param_groups[0]['lr']}\nDDP opt: {ddp_opt.param_groups[0]['lr']}\n" + ) + did_step = (((iteration + 1) % 2) == 0) or ((iteration + 1) == len(dataloader)) + if accelerator.num_processes > 1: + check_model_parameters( + model, + ddp_model, + did_step or sync_each_batch, # syncs at each grad_accum interval of if sync_each_batch==True + iteration, + rtol=1e-3, # needs a relative tolerance due to roundoff errors + ) + + if did_step: + opt.zero_grad() # flush gradients every accum step + ddp_opt.zero_grad() + + # Shuffle ddp_input on each iteration + torch.manual_seed(1337 + iteration) + GradientState._reset_state() + + +def test_dataloader_break(): + accelerator = Accelerator() + first_dset = RegressionDataset(length=80) + first_dataloader = DataLoader(first_dset, batch_size=16) + second_dset = RegressionDataset(length=96) + second_dataloader = DataLoader(second_dset, batch_size=16) + first_dataloader, second_dataloader = accelerator.prepare(first_dataloader, second_dataloader) + + assert accelerator.gradient_state.active_dataloader is None + for iteration, _ in enumerate(first_dataloader): + assert id(accelerator.gradient_state.active_dataloader) == id(first_dataloader) + if iteration < len(first_dataloader) - 1: + assert not accelerator.gradient_state.end_of_dataloader + if iteration == 1: + for batch_num, _ in enumerate(second_dataloader): + assert id(accelerator.gradient_state.active_dataloader) == id(second_dataloader) + if batch_num < len(second_dataloader) - 1: + assert not accelerator.gradient_state.end_of_dataloader + else: + assert accelerator.gradient_state.end_of_dataloader + else: + assert accelerator.gradient_state.end_of_dataloader + assert accelerator.gradient_state.active_dataloader is None + + +def main(): + accelerator = Accelerator() + state = accelerator.state + if state.local_process_index == 0: + print("**Test `accumulate` gradient accumulation with dataloader break**") + if state.distributed_type != DistributedType.XLA: + test_dataloader_break() + if state.distributed_type == DistributedType.NO: + if state.local_process_index == 0: + print("**Test NOOP `no_sync` context manager**") + test_noop_sync(accelerator) + if state.distributed_type in ( + DistributedType.MULTI_GPU, + DistributedType.MULTI_NPU, + DistributedType.MULTI_MLU, + DistributedType.MULTI_SDAA, + DistributedType.MULTI_MUSA, + DistributedType.MULTI_CPU, + DistributedType.MULTI_HPU, + DistributedType.MULTI_NEURON, + ): + if state.local_process_index == 0: + print("**Test Distributed `no_sync` context manager**") + test_distributed_sync(accelerator) + if state.local_process_index == 0: + print("**Test Distributed `no_sync` context manager with multiple forwards**") + test_distributed_sync_multiple_fwd(accelerator) + if state.distributed_type in ( + DistributedType.MULTI_GPU, + DistributedType.MULTI_NPU, + DistributedType.MULTI_MLU, + DistributedType.MULTI_SDAA, + DistributedType.MULTI_MUSA, + DistributedType.MULTI_HPU, + DistributedType.MULTI_NEURON, + ): + for split_batch in [True, False]: + for dispatch_batches in [True, False]: + for sync_each_batch in [True, False]: + if state.local_process_index == 0: + print( + "**Test `accumulate` gradient accumulation, ", + f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}` and `sync_each_batch={sync_each_batch}`**", + ) + test_gradient_accumulation(split_batch, dispatch_batches, sync_each_batch) + + # Currently will break on torch 2.0 +, need to investigate why + if state.local_process_index == 0: + print( + "**Test `accumulate` gradient accumulation with optimizer and scheduler, ", + "`split_batches=False`, `dispatch_batches=False`, `sync_each_batch=False`**", + ) + test_gradient_accumulation_with_opt_and_scheduler() + if state.distributed_type in ( + DistributedType.MULTI_GPU, + DistributedType.MULTI_NPU, + DistributedType.MULTI_MLU, + DistributedType.MULTI_SDAA, + DistributedType.MULTI_MUSA, + DistributedType.MULTI_HPU, + DistributedType.MULTI_NEURON, + ): + for split_batch in [True, False]: + for dispatch_batches in [True, False]: + for sync_each_batch in [True, False]: + if not split_batch and not dispatch_batches and not sync_each_batch: + continue + if state.local_process_index == 0: + print( + "**Test `accumulate` gradient accumulation with optimizer and scheduler, ", + f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}` and `sync_each_batch={sync_each_batch}`**", + ) + test_gradient_accumulation_with_opt_and_scheduler(split_batch, dispatch_batches, sync_each_batch) + state.destroy_process_group() + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/accelerate/test_utils/testing.py b/accelerate/test_utils/testing.py new file mode 100644 index 0000000000000000000000000000000000000000..c809dc97fe1070438b8e4d5ea301cd9e6f52d9dc --- /dev/null +++ b/accelerate/test_utils/testing.py @@ -0,0 +1,889 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import inspect +import io +import os +import re +import shutil +import subprocess +import sys +import tempfile +import unittest +from contextlib import contextmanager +from functools import partial +from pathlib import Path +from typing import Optional, Union +from unittest import mock + +import torch + +import accelerate + +from ..state import AcceleratorState +from ..utils import ( + check_cuda_fp8_capability, + compare_versions, + gather, + is_aim_available, + is_bnb_available, + is_clearml_available, + is_comet_ml_available, + is_cuda_available, + is_datasets_available, + is_deepspeed_available, + is_dvclive_available, + is_fp8_available, + is_fp16_available, + is_habana_gaudi1, + is_hpu_available, + is_import_timer_available, + is_matplotlib_available, + is_mlflow_available, + is_mlu_available, + is_mps_available, + is_musa_available, + is_neuron_available, + is_npu_available, + is_pandas_available, + is_pippy_available, + is_pytest_available, + is_schedulefree_available, + is_sdaa_available, + is_swanlab_available, + is_tensorboard_available, + is_timm_available, + is_torch_version, + is_torch_xla_available, + is_torchao_available, + is_torchdata_stateful_dataloader_available, + is_torchvision_available, + is_trackio_available, + is_transformer_engine_available, + is_transformer_engine_mxfp8_available, + is_transformers_available, + is_triton_available, + is_wandb_available, + is_xpu_available, + str_to_bool, +) + + +def get_backend(): + if is_torch_xla_available(): + return "xla", torch.cuda.device_count(), torch.cuda.memory_allocated + elif is_cuda_available(): + return "cuda", torch.cuda.device_count(), torch.cuda.memory_allocated + elif is_mps_available(min_version="2.0"): + return "mps", 1, torch.mps.current_allocated_memory + elif is_mps_available(): + return "mps", 1, lambda: 0 + elif is_mlu_available(): + return "mlu", torch.mlu.device_count(), torch.mlu.memory_allocated + elif is_sdaa_available(): + return "sdaa", torch.sdaa.device_count(), torch.sdaa.memory_allocated + elif is_musa_available(): + return "musa", torch.musa.device_count(), torch.musa.memory_allocated + elif is_npu_available(): + return "npu", torch.npu.device_count(), torch.npu.memory_allocated + elif is_xpu_available(): + return "xpu", torch.xpu.device_count(), torch.xpu.memory_allocated + elif is_hpu_available(): + return "hpu", torch.hpu.device_count(), torch.hpu.memory_allocated + elif is_neuron_available(): + return "neuron", torch.neuron.device_count(), torch.neuron.memory_allocated + else: + return "cpu", 1, lambda: 0 + + +torch_device, device_count, memory_allocated_func = get_backend() + + +def get_launch_command(**kwargs) -> list: + """ + Wraps around `kwargs` to help simplify launching from `subprocess`. + + Example: + ```python + # returns ['accelerate', 'launch', '--num_processes=2', '--device_count=2'] + get_launch_command(num_processes=2, device_count=2) + ``` + """ + command = ["accelerate", "launch"] + for k, v in kwargs.items(): + if isinstance(v, bool) and v: + command.append(f"--{k}") + elif v is not None: + command.append(f"--{k}={v}") + return command + + +DEFAULT_LAUNCH_COMMAND = get_launch_command(num_processes=device_count, monitor_interval=0.1) + + +def parse_flag_from_env(key, default=False): + try: + value = os.environ[key] + except KeyError: + # KEY isn't set, default to `default`. + _value = default + else: + # KEY is set, convert it to True or False. + try: + _value = str_to_bool(value) + except ValueError: + # More values are supported, but let's keep the message simple. + raise ValueError(f"If set, {key} must be yes or no.") + return _value + + +_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) + + +def skip(test_case): + "Decorator that skips a test unconditionally" + return unittest.skip("Test was skipped")(test_case) + + +def slow(test_case): + """ + Decorator marking a test as slow. Slow tests are skipped by default. Set the RUN_SLOW environment variable to a + truthy value to run them. + """ + return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) + + +def require_cpu(test_case): + """ + Decorator marking a test that must be only ran on the CPU. These tests are skipped when a GPU is available. + """ + return unittest.skipUnless(torch_device == "cpu", "test requires only a CPU")(test_case) + + +def require_non_cpu(test_case): + """ + Decorator marking a test that requires a hardware accelerator backend. These tests are skipped when there are no + hardware accelerator available. + """ + return unittest.skipUnless(torch_device != "cpu", "test requires a GPU")(test_case) + + +def require_cuda(test_case): + """ + Decorator marking a test that requires CUDA. These tests are skipped when there are no GPU available or when + TorchXLA is available. + """ + return unittest.skipUnless(is_cuda_available() and not is_torch_xla_available(), "test requires a GPU")(test_case) + + +def require_cuda_or_hpu(test_case): + """ + Decorator marking a test that requires CUDA or HPU. These tests are skipped when there are no GPU available or when + TorchXLA is available. + """ + return unittest.skipUnless( + (is_cuda_available() and not is_torch_xla_available()) or is_hpu_available(), "test requires a GPU or HPU" + )(test_case) + + +def require_xpu(test_case): + """ + Decorator marking a test that requires XPU. These tests are skipped when there are no XPU available. + """ + return unittest.skipUnless(is_xpu_available(), "test requires a XPU")(test_case) + + +def require_cuda_or_xpu(test_case): + """ + Decorator marking a test that requires CUDA or XPU. These tests are skipped when there are no GPU available or when + TorchXLA is available. + """ + cuda_condition = is_cuda_available() and not is_torch_xla_available() + xpu_condition = is_xpu_available() + return unittest.skipUnless(cuda_condition or xpu_condition, "test requires a CUDA GPU or XPU")(test_case) + + +def require_non_xpu(test_case): + """ + Decorator marking a test that should be skipped for XPU. + """ + return unittest.skipUnless(torch_device != "xpu", "test requires a non-XPU")(test_case) + + +def require_non_hpu(test_case): + """ + Decorator marking a test that should be skipped for HPU. + """ + return unittest.skipUnless(torch_device != "hpu", "test requires a non-HPU")(test_case) + + +def require_fp16(test_case): + """ + Decorator marking a test that requires FP16. These tests are skipped when FP16 is not supported. + """ + + return unittest.skipUnless(is_fp16_available(), "test requires FP16 support")(test_case) + + +def require_fp8(test_case): + """ + Decorator marking a test that requires FP8. These tests are skipped when FP8 is not supported. + """ + + # is_fp8_available only checks for libraries + # ideally it should check for device capability as well + fp8_is_available = is_fp8_available() + + if torch.cuda.is_available() and not check_cuda_fp8_capability(): + fp8_is_available = False + + if is_hpu_available() and is_habana_gaudi1(): + fp8_is_available = False + + return unittest.skipUnless(fp8_is_available, "test requires FP8 support")(test_case) + + +def require_fsdp2(test_case): + return unittest.skipUnless(is_torch_version(">=", "2.5.0"), "test requires FSDP2 (torch >= 2.5.0)")(test_case) + + +def require_mlu(test_case): + """ + Decorator marking a test that requires MLU. These tests are skipped when there are no MLU available. + """ + return unittest.skipUnless(is_mlu_available(), "test require a MLU")(test_case) + + +def require_sdaa(test_case): + """ + Decorator marking a test that requires SDAA. These tests are skipped when there are no SDAA available. + """ + return unittest.skipUnless(is_sdaa_available(), "test require a SDAA")(test_case) + + +def require_musa(test_case): + """ + Decorator marking a test that requires MUSA. These tests are skipped when there are no MUSA available. + """ + return unittest.skipUnless(is_musa_available(), "test require a MUSA")(test_case) + + +def require_npu(test_case): + """ + Decorator marking a test that requires NPU. These tests are skipped when there are no NPU available. + """ + return unittest.skipUnless(is_npu_available(), "test require a NPU")(test_case) + + +def require_neuron(test_case): + """ + Decorator marking a test that requires Neuron. These tests are skipped when there are no Neuron Cores available. + """ + return unittest.skipUnless(is_neuron_available(), "test require Neuron Cores")(test_case) + + +def require_mps(test_case): + """ + Decorator marking a test that requires MPS backend. These tests are skipped when torch doesn't support `mps` + backend. + """ + return unittest.skipUnless(is_mps_available(), "test requires a `mps` backend support in `torch`")(test_case) + + +def require_huggingface_suite(test_case): + """ + Decorator marking a test that requires transformers and datasets. These tests are skipped when they are not. + """ + return unittest.skipUnless( + is_transformers_available() and is_datasets_available(), + "test requires the Hugging Face suite", + )(test_case) + + +def require_transformers(test_case): + """ + Decorator marking a test that requires transformers. These tests are skipped when they are not. + """ + return unittest.skipUnless(is_transformers_available(), "test requires the transformers library")(test_case) + + +def require_timm(test_case): + """ + Decorator marking a test that requires timm. These tests are skipped when they are not. + """ + return unittest.skipUnless(is_timm_available(), "test requires the timm library")(test_case) + + +def require_torchvision(test_case): + """ + Decorator marking a test that requires torchvision. These tests are skipped when they are not. + """ + return unittest.skipUnless(is_torchvision_available(), "test requires the torchvision library")(test_case) + + +def require_triton(test_case): + """ + Decorator marking a test that requires triton. These tests are skipped when they are not. + """ + return unittest.skipUnless(is_triton_available(), "test requires the triton library")(test_case) + + +def require_schedulefree(test_case): + """ + Decorator marking a test that requires schedulefree. These tests are skipped when they are not. + """ + return unittest.skipUnless(is_schedulefree_available(), "test requires the schedulefree library")(test_case) + + +def require_bnb(test_case): + """ + Decorator marking a test that requires bitsandbytes. These tests are skipped when they are not. + """ + return unittest.skipUnless(is_bnb_available(), "test requires the bitsandbytes library")(test_case) + + +def require_tpu(test_case): + """ + Decorator marking a test that requires TPUs. These tests are skipped when there are no TPUs available. + """ + return unittest.skipUnless(is_torch_xla_available(check_is_tpu=True), "test requires TPU")(test_case) + + +def require_non_torch_xla(test_case): + """ + Decorator marking a test as requiring an environment without TorchXLA. These tests are skipped when TorchXLA is + available. + """ + return unittest.skipUnless(not is_torch_xla_available(), "test requires an env without TorchXLA")(test_case) + + +def require_single_device(test_case): + """ + Decorator marking a test that requires a single device. These tests are skipped when there is no hardware + accelerator available or number of devices is more than one. + """ + return unittest.skipUnless( + torch_device != "cpu" and device_count == 1, "test requires a single device accelerator" + )(test_case) + + +def require_single_gpu(test_case): + """ + Decorator marking a test that requires CUDA on a single GPU. These tests are skipped when there are no GPU + available or number of GPUs is more than one. + """ + return unittest.skipUnless(torch.cuda.device_count() == 1, "test requires a GPU")(test_case) + + +def require_single_xpu(test_case): + """ + Decorator marking a test that requires CUDA on a single XPU. These tests are skipped when there are no XPU + available or number of xPUs is more than one. + """ + return unittest.skipUnless(torch.xpu.device_count() == 1, "test requires a XPU")(test_case) + + +def require_multi_device(test_case): + """ + Decorator marking a test that requires a multi-device setup. These tests are skipped on a machine without multiple + devices. + """ + return unittest.skipUnless(device_count > 1, "test requires multiple hardware accelerators")(test_case) + + +def require_multi_gpu(test_case): + """ + Decorator marking a test that requires a multi-GPU setup. These tests are skipped on a machine without multiple + GPUs. + """ + return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case) + + +def require_multi_xpu(test_case): + """ + Decorator marking a test that requires a multi-XPU setup. These tests are skipped on a machine without multiple + XPUs. + """ + return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case) + + +def require_multi_gpu_or_xpu(test_case): + """ + Decorator marking a test that requires a multi-GPU setup. These tests are skipped on a machine without multiple + GPUs or XPUs. + """ + return unittest.skipUnless( + (is_cuda_available() or is_xpu_available()) and device_count > 1, "test requires multiple GPUs or XPUs" + )(test_case) + + +def require_deepspeed(test_case): + """ + Decorator marking a test that requires DeepSpeed installed. These tests are skipped when DeepSpeed isn't installed + """ + return unittest.skipUnless(is_deepspeed_available(), "test requires DeepSpeed")(test_case) + + +def require_tp(test_case): + """ + Decorator marking a test that requires TP installed. These tests are skipped when TP isn't installed + """ + return unittest.skipUnless( + is_torch_version(">=", "2.3.0") and compare_versions("transformers", ">=", "4.52.0"), + "test requires torch version >= 2.3.0 and transformers version >= 4.52.0", + )(test_case) + + +def require_torch_min_version(test_case=None, version=None): + """ + Decorator marking that a test requires a particular torch version to be tested. These tests are skipped when an + installed torch version is less than the required one. + """ + if test_case is None: + return partial(require_torch_min_version, version=version) + return unittest.skipUnless(is_torch_version(">=", version), f"test requires torch version >= {version}")(test_case) + + +def require_tensorboard(test_case): + """ + Decorator marking a test that requires tensorboard installed. These tests are skipped when tensorboard isn't + installed + """ + return unittest.skipUnless(is_tensorboard_available(), "test requires Tensorboard")(test_case) + + +def require_wandb(test_case): + """ + Decorator marking a test that requires wandb installed. These tests are skipped when wandb isn't installed + """ + return unittest.skipUnless(is_wandb_available(), "test requires wandb")(test_case) + + +def require_trackio(test_case): + """ + Decorator marking a test that requires trackio installed. These tests are skipped when trackio isn't installed + """ + return unittest.skipUnless(is_trackio_available(), "test requires trackio")(test_case) + + +def require_comet_ml(test_case): + """ + Decorator marking a test that requires comet_ml installed. These tests are skipped when comet_ml isn't installed + """ + return unittest.skipUnless(is_comet_ml_available(), "test requires comet_ml")(test_case) + + +def require_aim(test_case): + """ + Decorator marking a test that requires aim installed. These tests are skipped when aim isn't installed + """ + return unittest.skipUnless(is_aim_available(), "test requires aim")(test_case) + + +def require_clearml(test_case): + """ + Decorator marking a test that requires clearml installed. These tests are skipped when clearml isn't installed + """ + return unittest.skipUnless(is_clearml_available(), "test requires clearml")(test_case) + + +def require_dvclive(test_case): + """ + Decorator marking a test that requires dvclive installed. These tests are skipped when dvclive isn't installed + """ + return unittest.skipUnless(is_dvclive_available(), "test requires dvclive")(test_case) + + +def require_swanlab(test_case): + """ + Decorator marking a test that requires swanlab installed. These tests are skipped when swanlab isn't installed + """ + return unittest.skipUnless(is_swanlab_available(), "test requires swanlab")(test_case) + + +def require_pandas(test_case): + """ + Decorator marking a test that requires pandas installed. These tests are skipped when pandas isn't installed + """ + return unittest.skipUnless(is_pandas_available(), "test requires pandas")(test_case) + + +def require_mlflow(test_case): + """ + Decorator marking a test that requires mlflow installed. These tests are skipped when mlflow isn't installed + """ + return unittest.skipUnless(is_mlflow_available(), "test requires mlflow")(test_case) + + +def require_pippy(test_case): + """ + Decorator marking a test that requires pippy installed. These tests are skipped when pippy isn't installed It is + also checked if the test is running on a Gaudi1 device which doesn't support pippy. + """ + return unittest.skipUnless(is_pippy_available() and not is_habana_gaudi1(), "test requires pippy")(test_case) + + +def require_import_timer(test_case): + """ + Decorator marking a test that requires tuna interpreter installed. These tests are skipped when tuna isn't + installed + """ + return unittest.skipUnless(is_import_timer_available(), "test requires tuna interpreter")(test_case) + + +def require_transformer_engine(test_case): + """ + Decorator marking a test that requires transformers engine installed. These tests are skipped when transformers + engine isn't installed + """ + return unittest.skipUnless(is_transformer_engine_available(), "test requires transformers engine")(test_case) + + +def require_transformer_engine_mxfp8(test_case): + """ + Decorator marking a test that requires transformers engine MXFP8 block scaling available. These tests are skipped + when transformers engine MXFP8 block scaling isn't available + """ + return unittest.skipUnless( + is_transformer_engine_mxfp8_available(), "test requires transformers engine MXFP8 block scaling" + )(test_case) + + +def require_torchao(test_case): + """ + Decorator marking a test that requires torchao installed. These tests are skipped when torchao isn't installed + """ + return unittest.skipUnless(is_torchao_available(), "test requires torchao")(test_case) + + +def require_matplotlib(test_case): + """ + Decorator marking a test that requires matplotlib installed. These tests are skipped when matplotlib isn't + installed + """ + return unittest.skipUnless(is_matplotlib_available(), "test requires matplotlib")(test_case) + + +_atleast_one_tracker_available = ( + any([is_wandb_available(), is_tensorboard_available(), is_trackio_available(), is_swanlab_available()]) + and not is_comet_ml_available() +) + + +def require_trackers(test_case): + """ + Decorator marking that a test requires at least one tracking library installed. These tests are skipped when none + are installed + """ + return unittest.skipUnless( + _atleast_one_tracker_available, + "test requires at least one tracker to be available and for `comet_ml` to not be installed", + )(test_case) + + +def require_torchdata_stateful_dataloader(test_case): + """ + Decorator marking a test that requires torchdata.stateful_dataloader. + + These tests are skipped when torchdata with stateful_dataloader module isn't installed. + + """ + return unittest.skipUnless( + is_torchdata_stateful_dataloader_available(), "test requires torchdata.stateful_dataloader" + )(test_case) + + +def run_first(test_case): + """ + Decorator marking a test with order(1). When pytest-order plugin is installed, tests marked with this decorator are + guaranteed to run first. + + This is especially useful in some test settings like on a Gaudi instance where a Gaudi device can only be used by a + single process at a time. So we make sure all tests that run in a subprocess are launched first, to avoid device + allocation conflicts. + + If pytest is not installed, test will be returned as is. + """ + + if is_pytest_available(): + import pytest + + return pytest.mark.order(1)(test_case) + return test_case + + +class TempDirTestCase(unittest.TestCase): + """ + A TestCase class that keeps a single `tempfile.TemporaryDirectory` open for the duration of the class, wipes its + data at the start of a test, and then destroys it at the end of the TestCase. + + Useful for when a class or API requires a single constant folder throughout it's use, such as Weights and Biases + + The temporary directory location will be stored in `self.tmpdir` + """ + + clear_on_setup = True + + @classmethod + def setUpClass(cls): + "Creates a `tempfile.TemporaryDirectory` and stores it in `cls.tmpdir`" + cls.tmpdir = Path(tempfile.mkdtemp()) + + @classmethod + def tearDownClass(cls): + "Remove `cls.tmpdir` after test suite has finished" + if os.path.exists(cls.tmpdir): + shutil.rmtree(cls.tmpdir) + + def setUp(self): + "Destroy all contents in `self.tmpdir`, but not `self.tmpdir`" + if self.clear_on_setup: + for path in self.tmpdir.glob("**/*"): + if path.is_file(): + path.unlink() + elif path.is_dir(): + shutil.rmtree(path) + + +class AccelerateTestCase(unittest.TestCase): + """ + A TestCase class that will reset the accelerator state at the end of every test. Every test that checks or utilizes + the `AcceleratorState` class should inherit from this to avoid silent failures due to state being shared between + tests. + """ + + def tearDown(self): + super().tearDown() + # Reset the state of the AcceleratorState singleton. + AcceleratorState._reset_state(True) + + +class MockingTestCase(unittest.TestCase): + """ + A TestCase class designed to dynamically add various mockers that should be used in every test, mimicking the + behavior of a class-wide mock when defining one normally will not do. + + Useful when a mock requires specific information available only initialized after `TestCase.setUpClass`, such as + setting an environment variable with that information. + + The `add_mocks` function should be ran at the end of a `TestCase`'s `setUp` function, after a call to + `super().setUp()` such as: + ```python + def setUp(self): + super().setUp() + mocks = mock.patch.dict(os.environ, {"SOME_ENV_VAR", "SOME_VALUE"}) + self.add_mocks(mocks) + ``` + """ + + def add_mocks(self, mocks: Union[mock.Mock, list[mock.Mock]]): + """ + Add custom mocks for tests that should be repeated on each test. Should be called during + `MockingTestCase.setUp`, after `super().setUp()`. + + Args: + mocks (`mock.Mock` or list of `mock.Mock`): + Mocks that should be added to the `TestCase` after `TestCase.setUpClass` has been run + """ + self.mocks = mocks if isinstance(mocks, (tuple, list)) else [mocks] + for m in self.mocks: + m.start() + self.addCleanup(m.stop) + + +def are_the_same_tensors(tensor): + state = AcceleratorState() + tensor = tensor[None].clone().to(state.device) + tensors = gather(tensor).cpu() + tensor = tensor[0].cpu() + for i in range(tensors.shape[0]): + if not torch.equal(tensors[i], tensor): + return False + return True + + +class _RunOutput: + def __init__(self, returncode, stdout, stderr): + self.returncode = returncode + self.stdout = stdout + self.stderr = stderr + + +async def _read_stream(stream, callback): + while True: + line = await stream.readline() + if line: + callback(line) + else: + break + + +async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput: + if echo: + print("\nRunning: ", " ".join(cmd)) + + p = await asyncio.create_subprocess_exec( + cmd[0], + *cmd[1:], + stdin=stdin, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=env, + ) + + # note: there is a warning for a possible deadlock when using `wait` with huge amounts of data in the pipe + # https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait + # + # If it starts hanging, will need to switch to the following code. The problem is that no data + # will be seen until it's done and if it hangs for example there will be no debug info. + # out, err = await p.communicate() + # return _RunOutput(p.returncode, out, err) + + out = [] + err = [] + + def tee(line, sink, pipe, label=""): + line = line.decode("utf-8").rstrip() + sink.append(line) + if not quiet: + print(label, line, file=pipe) + + # XXX: the timeout doesn't seem to make any difference here + await asyncio.wait( + [ + asyncio.create_task(_read_stream(p.stdout, lambda l: tee(l, out, sys.stdout, label="stdout:"))), + asyncio.create_task(_read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:"))), + ], + timeout=timeout, + ) + return _RunOutput(await p.wait(), out, err) + + +def execute_subprocess_async(cmd: list, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput: + # Cast every path in `cmd` to a string + for i, c in enumerate(cmd): + if isinstance(c, Path): + cmd[i] = str(c) + + result = asyncio.run(_stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)) + + cmd_str = " ".join(cmd) + if result.returncode > 0: + stderr = "\n".join(result.stderr) + raise RuntimeError( + f"'{cmd_str}' failed with returncode {result.returncode}\n\n" + f"The combined stderr from workers follows:\n{stderr}" + ) + + return result + + +def pytest_xdist_worker_id(): + """ + Returns an int value of worker's numerical id under `pytest-xdist`'s concurrent workers `pytest -n N` regime, or 0 + if `-n 1` or `pytest-xdist` isn't being used. + """ + worker = os.environ.get("PYTEST_XDIST_WORKER", "gw0") + worker = re.sub(r"^gw", "", worker, 0, re.M) + return int(worker) + + +def get_torch_dist_unique_port(): + """ + Returns a port number that can be fed to `torch.distributed.launch`'s `--master_port` argument. + + Under `pytest-xdist` it adds a delta number based on a worker id so that concurrent tests don't try to use the same + port at once. + """ + port = 29500 + uniq_delta = pytest_xdist_worker_id() + return port + uniq_delta + + +class SubprocessCallException(Exception): + pass + + +def run_command(command: list[str], return_stdout=False, env=None): + """ + Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture + if an error occurred while running `command` + """ + # Cast every path in `command` to a string + for i, c in enumerate(command): + if isinstance(c, Path): + command[i] = str(c) + if env is None: + env = os.environ.copy() + try: + output = subprocess.check_output(command, stderr=subprocess.STDOUT, env=env) + if return_stdout: + if hasattr(output, "decode"): + output = output.decode("utf-8") + return output + except subprocess.CalledProcessError as e: + raise SubprocessCallException( + f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}" + ) from e + + +def path_in_accelerate_package(*components: str) -> Path: + """ + Get a path within the `accelerate` package's directory. + + Args: + *components: Components of the path to join after the package directory. + + Returns: + `Path`: The path to the requested file or directory. + """ + + accelerate_package_dir = Path(inspect.getfile(accelerate)).parent + return accelerate_package_dir.joinpath(*components) + + +@contextmanager +def assert_exception(exception_class: Exception, msg: Optional[str] = None) -> bool: + """ + Context manager to assert that the right `Exception` class was raised. + + If `msg` is provided, will check that the message is contained in the raised exception. + """ + was_ran = False + try: + yield + was_ran = True + except Exception as e: + assert isinstance(e, exception_class), f"Expected exception of type {exception_class} but got {type(e)}" + if msg is not None: + assert msg in str(e), f"Expected message '{msg}' to be in exception but got '{str(e)}'" + if was_ran: + raise AssertionError(f"Expected exception of type {exception_class} but ran without issue.") + + +def capture_call_output(func, *args, **kwargs): + """ + Takes in a `func` with `args` and `kwargs` and returns the captured stdout as a string + """ + captured_output = io.StringIO() + original_stdout = sys.stdout + try: + sys.stdout = captured_output + func(*args, **kwargs) + except Exception as e: + raise e + finally: + sys.stdout = original_stdout + return captured_output.getvalue() diff --git a/accelerate/test_utils/training.py b/accelerate/test_utils/training.py new file mode 100644 index 0000000000000000000000000000000000000000..609912bf542aa4d7eccf475689db066a5a8897c4 --- /dev/null +++ b/accelerate/test_utils/training.py @@ -0,0 +1,150 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +from torch.utils.data import DataLoader + +from accelerate.utils.dataclasses import DistributedType + + +class RegressionDataset: + def __init__(self, a=2, b=3, length=64, seed=None): + rng = np.random.default_rng(seed) + self.length = length + self.x = rng.normal(size=(length,)).astype(np.float32) + self.y = a * self.x + b + rng.normal(scale=0.1, size=(length,)).astype(np.float32) + + def __len__(self): + return self.length + + def __getitem__(self, i): + return {"x": self.x[i], "y": self.y[i]} + + +class RegressionModel(torch.nn.Module): + def __init__(self, a=0, b=0, double_output=False): + super().__init__() + self.a = torch.nn.Parameter(torch.tensor(a).float()) + self.b = torch.nn.Parameter(torch.tensor(b).float()) + self.first_batch = True + + def forward(self, x=None): + if self.first_batch: + print(f"Model dtype: {self.a.dtype}, {self.b.dtype}. Input dtype: {x.dtype}") + self.first_batch = False + return x * self.a + self.b + + +def mocked_dataloaders(accelerator, batch_size: int = 16): + from datasets import load_dataset + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") + data_files = {"train": "tests/test_samples/MRPC/train.csv", "validation": "tests/test_samples/MRPC/dev.csv"} + datasets = load_dataset("csv", data_files=data_files) + label_list = datasets["train"].unique("label") + + label_to_id = {v: i for i, v in enumerate(label_list)} + + def tokenize_function(examples): + # max_length=None => use the model max length (it's actually the default) + outputs = tokenizer( + examples["sentence1"], examples["sentence2"], truncation=True, max_length=None, padding="max_length" + ) + if "label" in examples: + outputs["labels"] = [label_to_id[l] for l in examples["label"]] + return outputs + + # Apply the method we just defined to all the examples in all the splits of the dataset + tokenized_datasets = datasets.map( + tokenize_function, + batched=True, + remove_columns=["sentence1", "sentence2", "label"], + ) + + def collate_fn(examples): + # On TPU it's best to pad everything to the same length or training will be very slow. + if accelerator.distributed_type == DistributedType.XLA: + return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt") + return tokenizer.pad(examples, padding="longest", return_tensors="pt") + + # Instantiate dataloaders. + train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=2) + eval_dataloader = DataLoader(tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=1) + + return train_dataloader, eval_dataloader + + +def mocked_dataloaders_for_autoregressive_models(accelerator, batch_size: int = 16): + from datasets import load_dataset + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-360M") + tokenizer.pad_token = tokenizer.eos_token + + data_files = {"train": "tests/test_samples/MRPC/train.csv", "validation": "tests/test_samples/MRPC/dev.csv"} + datasets = load_dataset("csv", data_files=data_files) + + def tokenize_function(examples): + # max_length=None => use the model max length (it's actually the default) + outputs = tokenizer(examples["sentence1"], truncation=True, max_length=None, return_attention_mask=False) + return outputs + + # Apply the method we just defined to all the examples in all the splits of the dataset + # starting with the main process first: + with accelerator.main_process_first(): + tokenized_datasets = datasets.map( + tokenize_function, + batched=True, + remove_columns=["sentence1", "sentence2", "label"], + ) + + def collate_fn(examples): + # On TPU it's best to pad everything to the same length or training will be very slow. + max_length = ( + 128 + if accelerator.distributed_type == DistributedType.XLA + else max([len(e["input_ids"]) for e in examples]) + ) + # When using mixed precision we want round multiples of 8/16 + if accelerator.mixed_precision == "fp8": + pad_to_multiple_of = 16 + elif accelerator.mixed_precision != "no": + pad_to_multiple_of = 8 + else: + pad_to_multiple_of = None + + batch = tokenizer.pad( + examples, + padding="max_length", + max_length=max_length + 1, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors="pt", + ) + + batch["labels"] = batch["input_ids"][:, 1:] + batch["input_ids"] = batch["input_ids"][:, :-1] + if "attention_mask" in batch: + batch["attention_mask"] = batch["attention_mask"][:, :-1] + + batch["labels"] = torch.where(batch["labels"] == tokenizer.pad_token_id, -100, batch["labels"]) + + return batch + + # Instantiate dataloaders. + train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=False, collate_fn=collate_fn, batch_size=2) + eval_dataloader = DataLoader(tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=1) + + return train_dataloader, eval_dataloader diff --git a/accelerate/utils/__init__.py b/accelerate/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..76b930e83014cc2e0efb801266a75c4cdc5f4bf7 --- /dev/null +++ b/accelerate/utils/__init__.py @@ -0,0 +1,304 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ..parallelism_config import ParallelismConfig +from .ao import convert_model_to_fp8_ao, filter_first_and_last_linear_layers, has_ao_layers +from .constants import ( + MITA_PROFILING_AVAILABLE_PYTORCH_VERSION, + MODEL_NAME, + OPTIMIZER_NAME, + PROFILE_PATTERN_NAME, + RNG_STATE_NAME, + SAFE_MODEL_NAME, + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + SAFE_WEIGHTS_PATTERN_NAME, + SAMPLER_NAME, + SCALER_NAME, + SCHEDULER_NAME, + TORCH_DISTRIBUTED_OPERATION_TYPES, + TORCH_LAUNCH_PARAMS, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + WEIGHTS_PATTERN_NAME, + XPU_PROFILING_AVAILABLE_PYTORCH_VERSION, +) +from .dataclasses import ( + AORecipeKwargs, + AutocastKwargs, + BnbQuantizationConfig, + ComputeEnvironment, + CustomDtype, + DataLoaderConfiguration, + DDPCommunicationHookType, + DeepSpeedPlugin, + DeepSpeedSequenceParallelConfig, + DistributedDataParallelKwargs, + DistributedType, + DynamoBackend, + FP8RecipeKwargs, + FullyShardedDataParallelPlugin, + GradientAccumulationPlugin, + GradScalerKwargs, + InitProcessGroupKwargs, + KwargsHandler, + LoggerType, + MegatronLMPlugin, + MSAMPRecipeKwargs, + PrecisionType, + ProfileKwargs, + ProjectConfiguration, + RNGType, + SageMakerDistributedType, + TensorInformation, + TERecipeKwargs, + TorchContextParallelConfig, + TorchDynamoPlugin, + TorchTensorParallelConfig, + TorchTensorParallelPlugin, + add_model_config_to_megatron_parser, +) +from .environment import ( + are_libraries_initialized, + check_cuda_fp8_capability, + check_cuda_p2p_ib_support, + clear_environment, + convert_dict_to_env_variables, + get_cpu_distributed_information, + get_current_device_type, + get_gpu_info, + get_int_from_env, + parse_choice_from_env, + parse_flag_from_env, + patch_environment, + purge_accelerate_environment, + set_numa_affinity, + str_to_bool, +) +from .imports import ( + deepspeed_required, + is_4bit_bnb_available, + is_8bit_bnb_available, + is_aim_available, + is_bf16_available, + is_bitsandbytes_multi_backend_available, + is_bnb_available, + is_boto3_available, + is_clearml_available, + is_comet_ml_available, + is_cuda_available, + is_datasets_available, + is_deepspeed_available, + is_dvclive_available, + is_fp8_available, + is_fp16_available, + is_habana_gaudi1, + is_hpu_available, + is_import_timer_available, + is_lomo_available, + is_matplotlib_available, + is_megatron_lm_available, + is_mlflow_available, + is_mlu_available, + is_mps_available, + is_msamp_available, + is_musa_available, + is_neuron_available, + is_npu_available, + is_pandas_available, + is_peft_available, + is_pippy_available, + is_pynvml_available, + is_pytest_available, + is_rich_available, + is_sagemaker_available, + is_schedulefree_available, + is_sdaa_available, + is_swanlab_available, + is_tensorboard_available, + is_timm_available, + is_torch_xla_available, + is_torchao_available, + is_torchdata_available, + is_torchdata_stateful_dataloader_available, + is_torchvision_available, + is_trackio_available, + is_transformer_engine_available, + is_transformer_engine_mxfp8_available, + is_transformers_available, + is_triton_available, + is_wandb_available, + is_weights_only_available, + is_xccl_available, + is_xpu_available, + torchao_required, +) +from .modeling import ( + align_module_device, + calculate_maximum_sizes, + check_device_map, + check_tied_parameters_in_config, + check_tied_parameters_on_same_device, + compute_module_sizes, + convert_file_size_to_int, + dtype_byte_size, + find_tied_parameters, + get_balanced_memory, + get_grad_scaler, + get_max_layer_size, + get_max_memory, + get_mixed_precision_context_manager, + has_offloaded_params, + id_tensor_storage, + infer_auto_device_map, + is_peft_model, + load_checkpoint_in_model, + load_offloaded_weights, + load_state_dict, + named_module_tensors, + retie_parameters, + set_module_tensor_to_device, +) +from .offload import ( + OffloadedWeightsLoader, + PrefixedDataset, + extract_submodules_state_dict, + load_offloaded_weight, + offload_state_dict, + offload_weight, + save_offload_index, +) +from .operations import ( + CannotPadNestedTensorWarning, + GatheredParameters, + broadcast, + broadcast_object_list, + concatenate, + convert_outputs_to_fp32, + convert_to_fp32, + copy_tensor_to_devices, + find_batch_size, + find_device, + gather, + gather_object, + get_data_structure, + honor_type, + ignorant_find_batch_size, + initialize_tensors, + is_namedtuple, + is_tensor_information, + is_torch_tensor, + listify, + pad_across_processes, + pad_input_tensors, + recursively_apply, + reduce, + send_to_device, + slice_tensors, +) +from .versions import compare_versions, is_torch_version + + +if is_deepspeed_available(): + from .deepspeed import ( + DeepSpeedEngineWrapper, + DeepSpeedOptimizerWrapper, + DeepSpeedSchedulerWrapper, + DummyOptim, + DummyScheduler, + HfDeepSpeedConfig, + get_active_deepspeed_plugin, + map_pytorch_optim_to_deepspeed, + ) + +from .bnb import has_4bit_bnb_layers, load_and_quantize_model +from .fsdp_utils import ( + disable_fsdp_ram_efficient_loading, + enable_fsdp_ram_efficient_loading, + ensure_weights_retied, + fsdp2_apply_ac, + fsdp2_canonicalize_names, + fsdp2_load_full_state_dict, + fsdp2_prepare_model, + fsdp2_switch_optimizer_parameters, + get_fsdp2_grad_scaler, + load_fsdp_model, + load_fsdp_optimizer, + merge_fsdp_weights, + save_fsdp_model, + save_fsdp_optimizer, +) +from .launch import ( + PrepareForLaunch, + _filter_args, + prepare_deepspeed_cmd_env, + prepare_multi_gpu_env, + prepare_sagemager_args_inputs, + prepare_simple_launcher_cmd_env, + prepare_tpu, +) + +# For docs +from .megatron_lm import ( + AbstractTrainStep, + BertTrainStep, + GPTTrainStep, + MegatronLMDummyDataLoader, + MegatronLMDummyScheduler, + T5TrainStep, + avg_losses_across_data_parallel_group, +) + + +if is_megatron_lm_available(): + from .megatron_lm import ( + MegatronEngine, + MegatronLMOptimizerWrapper, + MegatronLMSchedulerWrapper, + gather_across_data_parallel_groups, + ) + from .megatron_lm import initialize as megatron_lm_initialize + from .megatron_lm import prepare_data_loader as megatron_lm_prepare_data_loader + from .megatron_lm import prepare_model_optimizer_scheduler as megatron_lm_prepare_model_optimizer_scheduler + from .megatron_lm import prepare_optimizer as megatron_lm_prepare_optimizer + from .megatron_lm import prepare_scheduler as megatron_lm_prepare_scheduler +from .memory import find_executable_batch_size, release_memory +from .other import ( + check_os_kernel, + clean_state_dict_for_safetensors, + compile_regions, + compile_regions_deepspeed, + convert_bytes, + extract_model_from_parallel, + get_module_children_bottom_up, + get_pretty_name, + has_compiled_regions, + is_compiled_module, + is_port_in_use, + load, + merge_dicts, + model_has_dtensor, + recursive_getattr, + save, + wait_for_everyone, + write_basic_config, +) +from .random import set_seed, synchronize_rng_state, synchronize_rng_states +from .torch_xla import install_xla +from .tqdm import tqdm +from .transformer_engine import ( + apply_fp8_autowrap, + contextual_fp8_autocast, + convert_model, + has_transformer_engine_layers, +) diff --git a/accelerate/utils/ao.py b/accelerate/utils/ao.py new file mode 100644 index 0000000000000000000000000000000000000000..5cc417d72829006da38549b42fd68caae27b1204 --- /dev/null +++ b/accelerate/utils/ao.py @@ -0,0 +1,143 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Needed utilities for torchao FP8 training. +""" + +from functools import partial +from typing import TYPE_CHECKING, Callable, Optional + +import torch + +from .imports import is_torchao_available, torchao_required + + +if TYPE_CHECKING: + if is_torchao_available(): + from torchao.float8.float8_linear import Float8LinearConfig + + +def find_first_last_linear_layers(model: torch.nn.Module): + """ + Finds the first and last linear layer names in a model. + + This is needed during FP8 to avoid issues with instability by keeping the first and last layers unquantized. + + Ref: https://x.com/xariusrke/status/1826669142604141052 + """ + first_linear, last_linear = None, None + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + if first_linear is None: + first_linear = name + last_linear = name + return first_linear, last_linear + + +def filter_linear_layers(module, fqn: str, layers_to_filter: list[str]) -> bool: + """ + A function which will check if `module` is: + - a `torch.nn.Linear` layer + - has in_features and out_features divisible by 16 + - is not part of `layers_to_filter` + + Args: + module (`torch.nn.Module`): + The module to check. + fqn (`str`): + The fully qualified name of the layer. + layers_to_filter (`List[str]`): + The list of layers to filter. + """ + if isinstance(module, torch.nn.Linear): + if module.in_features % 16 != 0 or module.out_features % 16 != 0: + return False + if fqn in layers_to_filter: + return False + return True + + +def filter_first_and_last_linear_layers(module, fqn: str) -> bool: + """ + A filter function which will filter out all linear layers except the first and last. + + + + For stability reasons, we skip the first and last linear layers Otherwise can lead to the model not training or + converging properly + + + + Args: + module (`torch.nn.Module`): + The module to check. + fqn (`str`): + The fully qualified name of the layer. + """ + first_linear, last_linear = find_first_last_linear_layers(module) + return filter_linear_layers(module, fqn, layers_to_filter=[first_linear, last_linear]) + + +@torchao_required +def has_ao_layers(model: torch.nn.Module): + from torchao.float8.float8_linear import Float8Linear + + for name, module in model.named_modules(): + if isinstance(module, Float8Linear): + return True + return False + + +@torchao_required +def convert_model_to_fp8_ao( + model: torch.nn.Module, + config: Optional["Float8LinearConfig"] = None, + module_filter_func: Optional[Callable] = filter_first_and_last_linear_layers, +): + """ + Converts all `nn.Linear` layers in the model (except the first and last) to torchao's `Float8Linear` layer inplace. + + Args: + model (`torch.nn.Module`): + The model to convert. + config (`torchao.float8.Float8LinearConfig`, *optional*): + The configuration for the FP8 training. Recommended to utilize + `torchao.float8.recipe_name_to_linear_config` to generate this. In general, the default config should be + sufficient (what is passed when set to `None`). + module_filter_func (`Callable`, *optional*, defaults to `filter_linear_layers`): + Optional function that must take in a module and layer name, and returns a boolean indicating whether the + module should be converted to FP8. Defaults to `filter_linear_layers`. See it for an example. + + Example: + + ```python + from accelerate.utils.ao import convert_model_to_fp8_ao + from accelerate import Accelerator + + accelerator = Accelerator( + + model = MyModel() + model.to(accelerator.device) + convert_to_float8_training(model) + + model.train() + ``` + """ + from torchao.float8 import convert_to_float8_training + + first_linear, last_linear = find_first_last_linear_layers(model) + if module_filter_func is None: + module_filter_func = partial(filter_linear_layers, layers_to_filter=[first_linear, last_linear]) + convert_to_float8_training(model, module_filter_fn=module_filter_func, config=config) diff --git a/accelerate/utils/bnb.py b/accelerate/utils/bnb.py new file mode 100644 index 0000000000000000000000000000000000000000..cb698e804c492e06614230f0daf43888327c5a37 --- /dev/null +++ b/accelerate/utils/bnb.py @@ -0,0 +1,464 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import os +from copy import deepcopy +from typing import Optional, Union + +import torch +import torch.nn as nn + +from accelerate.utils.imports import ( + is_4bit_bnb_available, + is_8bit_bnb_available, +) + +from ..big_modeling import dispatch_model, init_empty_weights +from .dataclasses import BnbQuantizationConfig +from .modeling import ( + find_tied_parameters, + get_balanced_memory, + infer_auto_device_map, + load_checkpoint_in_model, + offload_weight, + set_module_tensor_to_device, +) + + +logger = logging.getLogger(__name__) + + +def load_and_quantize_model( + model: torch.nn.Module, + bnb_quantization_config: BnbQuantizationConfig, + weights_location: Optional[Union[str, os.PathLike]] = None, + device_map: Optional[dict[str, Union[int, str, torch.device]]] = None, + no_split_module_classes: Optional[list[str]] = None, + max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None, + offload_folder: Optional[Union[str, os.PathLike]] = None, + offload_state_dict: bool = False, +): + """ + This function will quantize the input model with the associated config passed in `bnb_quantization_config`. If the + model is in the meta device, we will load and dispatch the weights according to the `device_map` passed. If the + model is already loaded, we will quantize the model and put the model on the GPU, + + Args: + model (`torch.nn.Module`): + Input model. The model can be already loaded or on the meta device + bnb_quantization_config (`BnbQuantizationConfig`): + The bitsandbytes quantization parameters + weights_location (`str` or `os.PathLike`): + The folder weights_location to load. It can be: + - a path to a file containing a whole model state dict + - a path to a `.json` file containing the index to a sharded checkpoint + - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint. + - a path to a folder containing a unique pytorch_model.bin file. + device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer + name, once a given module name is inside, every submodule of it will be sent to the same device. + no_split_module_classes (`List[str]`, *optional*): + A list of layer class names that should never be split across device (for instance any layer that has a + residual connection). + max_memory (`Dict`, *optional*): + A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + If the `device_map` contains any value `"disk"`, the folder where we will offload weights. + offload_state_dict (`bool`, *optional*, defaults to `False`): + If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if + the weight of the CPU state dict + the biggest shard does not fit. + + Returns: + `torch.nn.Module`: The quantized model + """ + + load_in_4bit = bnb_quantization_config.load_in_4bit + load_in_8bit = bnb_quantization_config.load_in_8bit + + if load_in_8bit and not is_8bit_bnb_available(): + raise ImportError( + "You have a version of `bitsandbytes` that is not compatible with 8bit quantization," + " make sure you have the latest version of `bitsandbytes` installed." + ) + if load_in_4bit and not is_4bit_bnb_available(): + raise ValueError( + "You have a version of `bitsandbytes` that is not compatible with 4bit quantization," + "make sure you have the latest version of `bitsandbytes` installed." + ) + + modules_on_cpu = [] + # custom device map + if isinstance(device_map, dict) and len(device_map.keys()) > 1: + modules_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]] + + # We keep some modules such as the lm_head in their original dtype for numerical stability reasons + if bnb_quantization_config.skip_modules is None: + bnb_quantization_config.skip_modules = get_keys_to_not_convert(model) + + # add cpu modules to skip modules only for 4-bit modules + if load_in_4bit: + bnb_quantization_config.skip_modules.extend(modules_on_cpu) + modules_to_not_convert = bnb_quantization_config.skip_modules + + # We add the modules we want to keep in full precision + if bnb_quantization_config.keep_in_fp32_modules is None: + bnb_quantization_config.keep_in_fp32_modules = [] + keep_in_fp32_modules = bnb_quantization_config.keep_in_fp32_modules + modules_to_not_convert.extend(keep_in_fp32_modules) + + # compatibility with peft + model.is_loaded_in_4bit = load_in_4bit + model.is_loaded_in_8bit = load_in_8bit + + model_device = get_parameter_device(model) + if model_device.type != "meta": + # quantization of an already loaded model + logger.warning( + "It is not recommended to quantize a loaded model. " + "The model should be instantiated under the `init_empty_weights` context manager." + ) + model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert) + # convert param to the right dtype + dtype = bnb_quantization_config.torch_dtype + for name, param in model.named_parameters(): + if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules): + param.data = param.data.to(torch.float32) + elif torch.is_floating_point(param): + param.data = param.data.to(dtype) + if model_device.type == "cuda": + model.cuda(torch.cuda.current_device()) + torch.cuda.empty_cache() + elif torch.cuda.is_available(): + model.to(torch.cuda.current_device()) + elif torch.xpu.is_available(): + model.to(torch.xpu.current_device()) + else: + raise RuntimeError("No GPU or Intel XPU found. A GPU or Intel XPU is needed for quantization.") + logger.info( + f"The model device type is {model_device.type}. However, gpu or intel xpu is needed for quantization." + "We move the model to it." + ) + return model + + elif weights_location is None: + raise RuntimeError( + f"`weights_location` needs to be the folder path containing the weights of the model, but we found {weights_location} " + ) + + else: + with init_empty_weights(): + model = replace_with_bnb_layers( + model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert + ) + device_map = get_quantized_model_device_map( + model, + bnb_quantization_config, + device_map, + max_memory=max_memory, + no_split_module_classes=no_split_module_classes, + ) + if offload_state_dict is None and device_map is not None and "disk" in device_map.values(): + offload_state_dict = True + + offload = any(x in list(device_map.values()) for x in ["cpu", "disk"]) + + load_checkpoint_in_model( + model, + weights_location, + device_map, + dtype=bnb_quantization_config.torch_dtype, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + keep_in_fp32_modules=bnb_quantization_config.keep_in_fp32_modules, + offload_8bit_bnb=load_in_8bit and offload, + ) + return dispatch_model(model, device_map=device_map, offload_dir=offload_folder) + + +def get_quantized_model_device_map( + model, bnb_quantization_config, device_map=None, max_memory=None, no_split_module_classes=None +): + if device_map is None: + if torch.cuda.is_available(): + device_map = {"": torch.cuda.current_device()} + elif torch.xpu.is_available(): + device_map = {"": torch.xpu.current_device()} + else: + raise RuntimeError("No GPU found. A GPU is needed for quantization.") + logger.info("The device_map was not initialized.Setting device_map to `{'':torch.cuda.current_device()}`.") + + if isinstance(device_map, str): + if device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]: + raise ValueError( + "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or " + "'sequential'." + ) + + special_dtypes = {} + special_dtypes.update( + { + name: bnb_quantization_config.torch_dtype + for name, _ in model.named_parameters() + if any(m in name for m in bnb_quantization_config.skip_modules) + } + ) + special_dtypes.update( + { + name: torch.float32 + for name, _ in model.named_parameters() + if any(m in name for m in bnb_quantization_config.keep_in_fp32_modules) + } + ) + + kwargs = {} + kwargs["special_dtypes"] = special_dtypes + kwargs["no_split_module_classes"] = no_split_module_classes + kwargs["dtype"] = bnb_quantization_config.target_dtype + + # get max_memory for each device. + if device_map != "sequential": + max_memory = get_balanced_memory( + model, + low_zero=(device_map == "balanced_low_0"), + max_memory=max_memory, + **kwargs, + ) + + kwargs["max_memory"] = max_memory + device_map = infer_auto_device_map(model, **kwargs) + + if isinstance(device_map, dict): + # check if don't have any quantized module on the cpu + modules_not_to_convert = bnb_quantization_config.skip_modules + bnb_quantization_config.keep_in_fp32_modules + + device_map_without_some_modules = { + key: device_map[key] for key in device_map.keys() if key not in modules_not_to_convert + } + for device in ["cpu", "disk"]: + if device in device_map_without_some_modules.values(): + if bnb_quantization_config.load_in_4bit: + raise ValueError( + """ + Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit + the quantized model. If you want to dispatch the model on the CPU or the disk while keeping + these modules in `torch_dtype`, you need to pass a custom `device_map` to + `load_and_quantize_model`. Check + https://huggingface.co/docs/accelerate/main/en/usage_guides/quantization#offload-modules-to-cpu-and-disk + for more details. + """ + ) + else: + logger.info( + "Some modules are are offloaded to the CPU or the disk. Note that these modules will be converted to 8-bit" + ) + del device_map_without_some_modules + return device_map + + +def replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=None, current_key_name=None): + """ + A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules or by `bnb.nn.Linear4bit` + modules from the `bitsandbytes`library. The function will be run recursively and replace `torch.nn.Linear` modules. + + Parameters: + model (`torch.nn.Module`): + Input model or `torch.nn.Module` as the function is run recursively. + modules_to_not_convert (`List[str]`): + Names of the modules to not quantize convert. In practice we keep the `lm_head` in full precision for + numerical stability reasons. + current_key_name (`List[str]`, *optional*): + An array to track the current key of the recursion. This is used to check whether the current key (part of + it) is not in the list of modules to not convert. + """ + + if modules_to_not_convert is None: + modules_to_not_convert = [] + + model, has_been_replaced = _replace_with_bnb_layers( + model, bnb_quantization_config, modules_to_not_convert, current_key_name + ) + if not has_been_replaced: + logger.warning( + "You are loading your model in 8bit or 4bit but no linear modules were found in your model." + " this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers." + " Please double check your model architecture, or submit an issue on github if you think this is" + " a bug." + ) + return model + + +def _replace_with_bnb_layers( + model, + bnb_quantization_config, + modules_to_not_convert=None, + current_key_name=None, +): + """ + Private method that wraps the recursion for module replacement. + + Returns the converted model and a boolean that indicates if the conversion has been successful or not. + """ + # bitsandbytes will initialize device(e.g. CUDA, XPU) on import, so it needs to be imported lazily + import bitsandbytes as bnb + + has_been_replaced = False + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + if isinstance(module, nn.Linear) and name not in modules_to_not_convert: + # Check if the current key is not in the `modules_to_not_convert` + current_key_name_str = ".".join(current_key_name) + proceed = True + for key in modules_to_not_convert: + if ( + (key in current_key_name_str) and (key + "." in current_key_name_str) + ) or key == current_key_name_str: + proceed = False + break + if proceed: + # Load bnb module with empty weight and replace ``nn.Linear` module + if bnb_quantization_config.load_in_8bit: + bnb_module = bnb.nn.Linear8bitLt( + module.in_features, + module.out_features, + module.bias is not None, + has_fp16_weights=False, + threshold=bnb_quantization_config.llm_int8_threshold, + ) + elif bnb_quantization_config.load_in_4bit: + bnb_module = bnb.nn.Linear4bit( + module.in_features, + module.out_features, + module.bias is not None, + bnb_quantization_config.bnb_4bit_compute_dtype, + compress_statistics=bnb_quantization_config.bnb_4bit_use_double_quant, + quant_type=bnb_quantization_config.bnb_4bit_quant_type, + ) + else: + raise ValueError("load_in_8bit and load_in_4bit can't be both False") + bnb_module.weight.data = module.weight.data + if module.bias is not None: + bnb_module.bias.data = module.bias.data + bnb_module.requires_grad_(False) + setattr(model, name, bnb_module) + has_been_replaced = True + if len(list(module.children())) > 0: + _, _has_been_replaced = _replace_with_bnb_layers( + module, bnb_quantization_config, modules_to_not_convert, current_key_name + ) + has_been_replaced = has_been_replaced | _has_been_replaced + # Remove the last key for recursion + current_key_name.pop(-1) + return model, has_been_replaced + + +def get_keys_to_not_convert(model): + r""" + An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules + we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want + to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in + int8. + + Parameters: + model (`torch.nn.Module`): + Input model + """ + # Create a copy of the model + with init_empty_weights(): + tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager` + + tied_params = find_tied_parameters(tied_model) + # For compatibility with Accelerate < 0.18 + if isinstance(tied_params, dict): + tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys()) + else: + tied_keys = sum(tied_params, []) + has_tied_params = len(tied_keys) > 0 + + # Check if it is a base model + is_base_model = False + if hasattr(model, "base_model_prefix"): + is_base_model = not hasattr(model, model.base_model_prefix) + + # Ignore this for base models (BertModel, GPT2Model, etc.) + if (not has_tied_params) and is_base_model: + return [] + + # otherwise they have an attached head + list_modules = list(model.named_children()) + list_last_module = [list_modules[-1][0]] + + # add last module together with tied weights + intersection = set(list_last_module) - set(tied_keys) + list_untouched = list(set(tied_keys)) + list(intersection) + + # remove ".weight" from the keys + names_to_remove = [".weight", ".bias"] + filtered_module_names = [] + for name in list_untouched: + for name_to_remove in names_to_remove: + if name_to_remove in name: + name = name.replace(name_to_remove, "") + filtered_module_names.append(name) + + return filtered_module_names + + +def has_4bit_bnb_layers(model): + """Check if we have `bnb.nn.Linear4bit` or `bnb.nn.Linear8bitLt` layers inside our model""" + # bitsandbytes will initialize device(e.g. CUDA, XPU) on import, so it needs to be imported lazily + import bitsandbytes as bnb + + for m in model.modules(): + if isinstance(m, bnb.nn.Linear4bit): + return True + return False + + +def get_parameter_device(parameter: nn.Module): + return next(parameter.parameters()).device + + +def quantize_and_offload_8bit(model, param, param_name, new_dtype, offload_folder, offload_index, fp16_statistics): + # if it is not quantized, we quantize and offload the quantized weights and the SCB stats + if fp16_statistics is None: + set_module_tensor_to_device(model, param_name, 0, dtype=new_dtype, value=param) + tensor_name = param_name + module = model + if "." in tensor_name: + splits = tensor_name.split(".") + for split in splits[:-1]: + new_module = getattr(module, split) + if new_module is None: + raise ValueError(f"{module} has no attribute {split}.") + module = new_module + tensor_name = splits[-1] + # offload weights + module._parameters[tensor_name].requires_grad = False + offload_weight(module._parameters[tensor_name], param_name, offload_folder, index=offload_index) + if hasattr(module._parameters[tensor_name], "SCB"): + offload_weight( + module._parameters[tensor_name].SCB, + param_name.replace("weight", "SCB"), + offload_folder, + index=offload_index, + ) + else: + offload_weight(param, param_name, offload_folder, index=offload_index) + offload_weight(fp16_statistics, param_name.replace("weight", "SCB"), offload_folder, index=offload_index) + + set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype, value=torch.empty(*param.size())) diff --git a/accelerate/utils/constants.py b/accelerate/utils/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..0eba7bc3d61cb012366f43a31873871f6d29a83d --- /dev/null +++ b/accelerate/utils/constants.py @@ -0,0 +1,108 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import operator as op + +import torch + + +SCALER_NAME = "scaler.pt" +MODEL_NAME = "pytorch_model" +SAFE_MODEL_NAME = "model" +RNG_STATE_NAME = "random_states" +OPTIMIZER_NAME = "optimizer" +SCHEDULER_NAME = "scheduler" +SAMPLER_NAME = "sampler" +PROFILE_PATTERN_NAME = "profile_{suffix}.json" +WEIGHTS_NAME = f"{MODEL_NAME}.bin" +WEIGHTS_PATTERN_NAME = "pytorch_model{suffix}.bin" +WEIGHTS_INDEX_NAME = f"{WEIGHTS_NAME}.index.json" +SAFE_WEIGHTS_NAME = f"{SAFE_MODEL_NAME}.safetensors" +SAFE_WEIGHTS_PATTERN_NAME = "model{suffix}.safetensors" +SAFE_WEIGHTS_INDEX_NAME = f"{SAFE_WEIGHTS_NAME}.index.json" +SAGEMAKER_PYTORCH_VERSION = "1.10.2" +SAGEMAKER_PYTHON_VERSION = "py38" +SAGEMAKER_TRANSFORMERS_VERSION = "4.17.0" +SAGEMAKER_PARALLEL_EC2_INSTANCES = ["ml.p3.16xlarge", "ml.p3dn.24xlarge", "ml.p4dn.24xlarge"] +FSDP_SHARDING_STRATEGY = ["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD", "HYBRID_SHARD_ZERO2"] +FSDP_AUTO_WRAP_POLICY = ["TRANSFORMER_BASED_WRAP", "SIZE_BASED_WRAP", "NO_WRAP"] +FSDP_BACKWARD_PREFETCH = ["BACKWARD_PRE", "BACKWARD_POST", "NO_PREFETCH"] +FSDP_STATE_DICT_TYPE = ["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] +FSDP2_STATE_DICT_TYPE = ["SHARDED_STATE_DICT", "FULL_STATE_DICT"] +FSDP_PYTORCH_VERSION = ( + "2.1.0.a0+32f93b1" # Technically should be 2.1.0, but MS-AMP uses this specific prerelease in their Docker image. +) +FSDP2_PYTORCH_VERSION = "2.6.0" +DTENSOR_PYTORCH_VERSION = "2.5.0" +FSDP_MODEL_NAME = "pytorch_model_fsdp" +DEEPSPEED_MULTINODE_LAUNCHERS = ["pdsh", "standard", "openmpi", "mvapich", "mpich", "nossh", "slurm"] +TORCH_DYNAMO_MODES = ["default", "reduce-overhead", "max-autotune"] +ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION = "2.2.0" +XPU_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.4.0" +MITA_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.1.0" +BETA_TP_AVAILABLE_PYTORCH_VERSION = "2.3.0" + +BETA_TP_AVAILABLE_TRANSFORMERS_VERSION = "4.52.0" +BETA_CP_AVAILABLE_PYTORCH_VERSION = "2.6.0" +BETA_SP_AVAILABLE_DEEPSPEED_VERSION = "0.18.2" + +STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} + +# These are the args for `torch.distributed.launch` for pytorch < 1.9 +TORCH_LAUNCH_PARAMS = [ + "nnodes", + "nproc_per_node", + "rdzv_backend", + "rdzv_endpoint", + "rdzv_id", + "rdzv_conf", + "standalone", + "max_restarts", + "monitor_interval", + "start_method", + "role", + "module", + "m", + "no_python", + "run_path", + "log_dir", + "r", + "redirects", + "t", + "tee", + "node_rank", + "master_addr", + "master_port", +] + +CUDA_DISTRIBUTED_TYPES = ["DEEPSPEED", "MULTI_GPU", "FSDP", "MEGATRON_LM", "TP"] +TORCH_DISTRIBUTED_OPERATION_TYPES = CUDA_DISTRIBUTED_TYPES + [ + "MULTI_NPU", + "MULTI_MLU", + "MULTI_SDAA", + "MULTI_MUSA", + "MULTI_XPU", + "MULTI_CPU", + "MULTI_HPU", + "MULTI_NEURON", +] +SUPPORTED_PYTORCH_LAYERS_FOR_UPCASTING = ( + torch.nn.Conv1d, + torch.nn.Conv2d, + torch.nn.Conv3d, + torch.nn.ConvTranspose1d, + torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d, + torch.nn.Linear, +) diff --git a/accelerate/utils/dataclasses.py b/accelerate/utils/dataclasses.py new file mode 100644 index 0000000000000000000000000000000000000000..43e1d904886523744c75245d539fe3a74c73b3f9 --- /dev/null +++ b/accelerate/utils/dataclasses.py @@ -0,0 +1,3199 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +General namespace and dataclass related classes +""" + +import argparse +import copy +import enum +import functools +import logging +import os +import warnings +from collections.abc import Iterable +from contextlib import contextmanager +from dataclasses import dataclass, field +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, get_args + +import torch + +from .constants import ( + BETA_CP_AVAILABLE_PYTORCH_VERSION, + BETA_TP_AVAILABLE_PYTORCH_VERSION, + BETA_TP_AVAILABLE_TRANSFORMERS_VERSION, + FSDP2_PYTORCH_VERSION, + FSDP_AUTO_WRAP_POLICY, + FSDP_BACKWARD_PREFETCH, + FSDP_SHARDING_STRATEGY, + MITA_PROFILING_AVAILABLE_PYTORCH_VERSION, + XPU_PROFILING_AVAILABLE_PYTORCH_VERSION, +) +from .environment import parse_flag_from_env, str_to_bool +from .imports import ( + is_cuda_available, + is_hpu_available, + is_mlu_available, + is_msamp_available, + is_musa_available, + is_npu_available, + is_torchao_available, + is_transformer_engine_available, + is_xpu_available, +) +from .versions import compare_versions, is_torch_version + + +if TYPE_CHECKING: + # Mock imports for type checking + from torchao.float8 import Float8LinearConfig + + +logger = logging.getLogger(__name__) + + +class KwargsHandler: + """ + Internal mixin that implements a `to_kwargs()` method for a dataclass. + """ + + def to_dict(self): + return copy.deepcopy(self.__dict__) + + def to_kwargs(self): + """ + Returns a dictionary containing the attributes with values different from the default of this class. + """ + # import clear_environment here to avoid circular import problem + from .environment import clear_environment + + with clear_environment(): + default_dict = self.__class__().to_dict() + this_dict = self.to_dict() + return {k: v for k, v in this_dict.items() if default_dict[k] != v} + + +class EnumWithContains(enum.EnumMeta): + "A metaclass that adds the ability to check if `self` contains an item with the `in` operator" + + def __contains__(cls, item): + try: + cls(item) + except ValueError: + return False + return True + + +class BaseEnum(enum.Enum, metaclass=EnumWithContains): + "An enum class that can get the value of an item with `str(Enum.key)`" + + def __str__(self): + return self.value + + @classmethod + def list(cls): + "Method to list all the possible items in `cls`" + return list(map(str, cls)) + + +@dataclass +class AutocastKwargs(KwargsHandler): + """ + Use this object in your [`Accelerator`] to customize how `torch.autocast` behaves. Please refer to the + documentation of this [context manager](https://pytorch.org/docs/stable/amp.html#torch.autocast) for more + information on each argument. + + Example: + + ```python + from accelerate import Accelerator + from accelerate.utils import AutocastKwargs + + kwargs = AutocastKwargs(cache_enabled=True) + accelerator = Accelerator(kwargs_handlers=[kwargs]) + ``` + """ + + enabled: bool = True + cache_enabled: Optional[bool] = None + + +class DDPCommunicationHookType(BaseEnum): + """ + Represents a type of communication hook used in DDP. + + Values: + + - **NO** -- no communication hook + - **FP16** -- DDP communication hook to compress the gradients in FP16 + - **BF16** -- DDP communication hook to compress the gradients in BF16 + - **POWER_SGD** -- DDP communication hook to use PowerSGD + - **BATCHED_POWER_SGD** -- DDP communication hook to use batched PowerSGD + """ + + NO = "no" + FP16 = "fp16" + BF16 = "bf16" + POWER_SGD = "power_sgd" + BATCHED_POWER_SGD = "batched_power_sgd" + + +@dataclass +class DistributedDataParallelKwargs(KwargsHandler): + """ + Use this object in your [`Accelerator`] to customize how your model is wrapped in a + `torch.nn.parallel.DistributedDataParallel`. Please refer to the documentation of this + [wrapper](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) for more + information on each argument. + + + + `gradient_as_bucket_view` is only available in PyTorch 1.7.0 and later versions. + + `static_graph` is only available in PyTorch 1.11.0 and later versions. + + + + Example: + + ```python + from accelerate import Accelerator + from accelerate.utils import DistributedDataParallelKwargs + + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator(kwargs_handlers=[kwargs]) + ``` + """ + + dim: int = 0 + broadcast_buffers: bool = True + bucket_cap_mb: int = 25 + find_unused_parameters: bool = False + check_reduction: bool = False + gradient_as_bucket_view: bool = False + static_graph: bool = False + + comm_hook: DDPCommunicationHookType = DDPCommunicationHookType.NO + comm_wrapper: Literal[ + DDPCommunicationHookType.NO, + DDPCommunicationHookType.FP16, + DDPCommunicationHookType.BF16, + ] = DDPCommunicationHookType.NO + comm_state_option: dict = field(default_factory=dict) + + def to_dict(self, ignore_keys=("comm_hook", "comm_wrapper", "comm_state_option")): + return {k: v for k, v in super().to_dict().items() if k not in ignore_keys} + + def register_comm_hook(self, model): + from torch.distributed.algorithms.ddp_comm_hooks import ( + default_hooks, + powerSGD_hook, + ) + + hook_map: dict[DDPCommunicationHookType, Callable] = { + DDPCommunicationHookType.FP16: default_hooks.fp16_compress_hook, + DDPCommunicationHookType.BF16: default_hooks.bf16_compress_hook, + DDPCommunicationHookType.POWER_SGD: powerSGD_hook.powerSGD_hook, + DDPCommunicationHookType.BATCHED_POWER_SGD: powerSGD_hook.batched_powerSGD_hook, + } + + wrapper_map: dict[DDPCommunicationHookType, Callable] = { + DDPCommunicationHookType.FP16: default_hooks.fp16_compress_wrapper, + DDPCommunicationHookType.BF16: default_hooks.bf16_compress_wrapper, + } + + hook: Optional[Callable] = hook_map.get(self.comm_hook) + wrapper: Optional[Callable] = wrapper_map.get(self.comm_wrapper) + + if hook and wrapper: + hook = wrapper(hook) + + if hook: + state = ( + powerSGD_hook.PowerSGDState(None, **self.comm_state_option) + if self.comm_hook + in ( + DDPCommunicationHookType.POWER_SGD, + DDPCommunicationHookType.BATCHED_POWER_SGD, + ) + else None + ) + model.register_comm_hook( + state=state, + hook=hook, + ) + + +@dataclass +class GradScalerKwargs(KwargsHandler): + """ + Use this object in your [`Accelerator`] to customize the behavior of mixed precision, specifically how the + `torch.amp.GradScaler` or `torch.cuda.amp.GradScaler` used is created. Please refer to the documentation of this + [scaler](https://pytorch.org/docs/stable/amp.html?highlight=gradscaler) for more information on each argument. + + + + `torch.cuda.amp.GradScaler` is only available in PyTorch 1.5.0 and later versions, and `torch.amp.GradScaler` is + only available in PyTorch 2.4.0 and later versions. + + + + Example: + + ```python + from accelerate import Accelerator + from accelerate.utils import GradScalerKwargs + + kwargs = GradScalerKwargs(backoff_factor=0.25) + accelerator = Accelerator(kwargs_handlers=[kwargs]) + ``` + """ + + init_scale: float = 65536.0 + growth_factor: float = 2.0 + backoff_factor: float = 0.5 + growth_interval: int = 2000 + enabled: bool = True + + +@dataclass +class InitProcessGroupKwargs(KwargsHandler): + """ + Use this object in your [`Accelerator`] to customize the initialization of the distributed processes. Please refer + to the documentation of this + [method](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for more + information on each argument. + + Note: If `timeout` is set to `None`, the default will be based upon how `backend` is set. + + ```python + from datetime import timedelta + from accelerate import Accelerator + from accelerate.utils import InitProcessGroupKwargs + + kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=800)) + accelerator = Accelerator(kwargs_handlers=[kwargs]) + ``` + """ + + backend: Optional[str] = "nccl" + init_method: Optional[str] = None + timeout: Optional[timedelta] = None + + def __post_init__(self): + if self.timeout is None: + seconds = 1800 if self.backend != "nccl" else 600 + self.timeout = timedelta(seconds=seconds) + + +# Literals +Backend = Literal["MSAMP", "TE"] +OptLevel = Literal["O1", "O2"] +FP8Format = Literal["HYBRID", "E4M3", "E5M2"] +AmaxComputeAlgorithm = Literal["max", "most_recent"] + + +# FP8 training recipe kwargs +@dataclass +class AORecipeKwargs(KwargsHandler): + """ + Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision + training with `torchao` FP8. + + Args: + config (`torchao.float8.Float8LinearConfig`, *optional*, default to `None`): + The configuration for the FP8 training. If `None`, a default config will be created with sensible + defaults for most use cases: + - `pad_inner_dim=True`: Pads matrix dimensions to be divisible by 16, required for `torch._scaled_mm` + operations to prevent runtime errors. + - `enable_fsdp_float8_all_gather=True`: Enables FP8 all-gather for FSDP2. This provides memory bandwidth + savings by casting parameters before the all-gather operation, saving 50% bandwidth compared to BF16. + + You can override these defaults by providing your own `Float8LinearConfig` instance. + module_filter_func (`Callable`, *optional*, default to `None`): + Optional function that must take in a module and layer name, and returns a boolean indicating whether the + module should be converted to FP8. Defaults to `accelerate.utils.ao.filter_linear_layers`. See it for an + example. + """ + + config: Optional["Float8LinearConfig"] = None + module_filter_func: Optional[Callable] = None + pad_inner_dim: Optional[bool] = None + enable_fsdp_float8_all_gather: Optional[bool] = None + + def __post_init__(self): + env_prefix = "ACCELERATE_FP8_" + if not is_torchao_available(): + raise ImportError("TorchAO is not available. Please install it or use a different backend.") + + if self.config is None: + from torchao.float8 import Float8LinearConfig + + # Check environment variables for overrides + if self.pad_inner_dim is None: + self.pad_inner_dim = parse_flag_from_env(env_prefix + "PAD_INNER_DIM", default=True) + if self.enable_fsdp_float8_all_gather is None: + self.enable_fsdp_float8_all_gather = parse_flag_from_env( + env_prefix + "ENABLE_FSDP_FLOAT8_ALL_GATHER", default=True + ) + self.config = Float8LinearConfig( + pad_inner_dim=self.pad_inner_dim, + enable_fsdp_float8_all_gather=self.enable_fsdp_float8_all_gather, + ) + + +@dataclass +class TERecipeKwargs(KwargsHandler): + """ + Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision + training with `transformer-engine`. + + + + For more information on the args, please refer to the API + [documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html). + + + + ```python + from accelerate import Accelerator + from accelerate.utils import TERecipeKwargs + + kwargs = TERecipeKwargs(fp8_format="HYBRID") + accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[kwargs]) + ``` + + Args: + use_autocast_during_eval (`bool`, *optional*, default to `False`): + Whether to use FP8 autocast during eval mode. Generally better metrics are found when this is `False`. + margin (`int`, *optional*, default to 0): + The margin to use for the gradient scaling. + interval (`int`, *optional*, default to 1): + The interval to use for how often the scaling factor is recomputed. + fp8_format (`str`, *optional*, default to "HYBRID"): + The format to use for the FP8 recipe. Must be one of `HYBRID`, `E4M3` or `E5M2`. (Generally `HYBRID` for + training, `E4M3` or `E5M2` for evaluation) + amax_history_len (`int`, *optional*, default to 1024): + The length of the history to use for the scaling factor computation + amax_compute_algo (`str`, *optional*, default to "most_recent"): + The algorithm to use for the scaling factor computation. Must be one of `max` or `most_recent`. + override_linear_precision (`tuple` of three `bool`, *optional*, default to `(False, False, False)`): + Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision. + """ + + use_autocast_during_eval: Optional[bool] = None + margin: Optional[int] = None + interval: Optional[int] = None + fp8_format: FP8Format = None + amax_history_len: Optional[int] = None + amax_compute_algo: AmaxComputeAlgorithm = None + override_linear_precision: tuple[bool, bool, bool] = None + use_mxfp8_block_scaling: Optional[bool] = None + + def __post_init__(self): + env_prefix = "ACCELERATE_FP8_" + if not is_transformer_engine_available(): + raise ImportError("TransformerEngine is not available. Please install it or use a different backend.") + if self.use_autocast_during_eval is None: + self.use_autocast_during_eval = parse_flag_from_env(env_prefix + "USE_AUTOCAST_DURING_EVAL") + if self.margin is None: + self.margin = int(os.environ.get(env_prefix + "MARGIN", 0)) + if self.interval is None: + self.interval = int(os.environ.get(env_prefix + "INTERVAL", 1)) + if self.fp8_format is None: + self.fp8_format = os.environ.get(env_prefix + "FORMAT", "HYBRID") + self.fp8_format = self.fp8_format.upper() + if self.fp8_format not in get_args(FP8Format): + raise ValueError(f"`fp8_format` must be one of {' or '.join(get_args(FP8Format))}.") + if self.amax_compute_algo is None: + self.amax_compute_algo = os.environ.get(env_prefix + "AMAX_COMPUTE_ALGO", "most_recent") + self.amax_compute_algo = self.amax_compute_algo.lower() + if self.amax_compute_algo not in get_args(AmaxComputeAlgorithm): + raise ValueError(f"`amax_compute_algo` must be one of {' or '.join(get_args(AmaxComputeAlgorithm))}") + if self.amax_history_len is None: + self.amax_history_len = int(os.environ.get(env_prefix + "AMAX_HISTORY_LEN", 1024)) + if self.override_linear_precision is None: + fprop = parse_flag_from_env(env_prefix + "OVERRIDE_FPROP") + dgrad = parse_flag_from_env(env_prefix + "OVERRIDE_DGRAD") + wgrad = parse_flag_from_env(env_prefix + "OVERRIDE_WGRAD") + self.override_linear_precision = (fprop, dgrad, wgrad) + if self.use_mxfp8_block_scaling is None: + self.use_mxfp8_block_scaling = parse_flag_from_env(env_prefix + "USE_MXFP8_BLOCK_SCALING") + + +@dataclass +class MSAMPRecipeKwargs(KwargsHandler): + """ + Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision + training with `ms-amp`. + """ + + opt_level: OptLevel = None + + def __post_init__(self): + env_prefix = "ACCELERATE_FP8_" + if self.opt_level is None: + self.opt_level = os.environ.get(env_prefix + "OPT_LEVEL", "O2") + if self.opt_level not in get_args(OptLevel): + raise ValueError(f"`opt_level` must be one of {' or '.join(get_args(OptLevel))}") + + +@dataclass +class FP8RecipeKwargs(TERecipeKwargs, MSAMPRecipeKwargs): + """ + Deprecated. Please use one of the proper FP8 recipe kwargs classes such as `TERecipeKwargs` or `MSAMPRecipeKwargs` + instead. + """ + + backend: Backend = None + + def __post_init__(self): + env_prefix = "ACCELERATE_FP8_" + warnings.warn( + "FP8RecipeKwargs is deprecated and will be removed in Accelerate v2.0.0. " + "Please use one of the proper FP8 recipe kwargs classes such as TERecipeKwargs or MSAMPRecipeKwargs instead.", + FutureWarning, + ) + default_backend = "msamp" if is_msamp_available() else "te" + if self.backend is None: + self.backend = os.environ.get(env_prefix + "BACKEND", default_backend) + self.backend = self.backend.upper() + if self.backend not in get_args(Backend): + raise ValueError("`backend` must be 'MSAMP' or 'TE' (TransformerEngine) to use `FP8RecipeKwargs`.") + super().__post_init__() + + +# Literal +ProfilerActivity = Literal["cpu", "xpu", "mtia", "cuda", "hpu"] + + +@dataclass +class ProfileKwargs(KwargsHandler): + """ + Use this object in your [`Accelerator`] to customize the initialization of the profiler. Please refer to the + documentation of this [context manager](https://pytorch.org/docs/stable/profiler.html#torch.profiler.profile) for + more information on each argument. + + + + `torch.profiler` is only available in PyTorch 1.8.1 and later versions. + + + + Example: + + ```python + from accelerate import Accelerator + from accelerate.utils import ProfileKwargs + + kwargs = ProfileKwargs(activities=["cpu", "cuda"]) + accelerator = Accelerator(kwargs_handlers=[kwargs]) + ``` + + Args: + activities (`List[str]`, *optional*, default to `None`): + The list of activity groups to use in profiling. Must be one of `"cpu"`, `"xpu"`, `"mtia"`, "hpu" or + `"cuda"`. + schedule_option (`Dict[str, int]`, *optional*, default to `None`): + The schedule option to use for the profiler. Available keys are `wait`, `warmup`, `active`, `repeat` and + `skip_first`. The profiler will skip the first `skip_first` steps, then wait for `wait` steps, then do the + warmup for the next `warmup` steps, then do the active recording for the next `active` steps and then + repeat the cycle starting with `wait` steps. The optional number of cycles is specified with the `repeat` + parameter, the zero value means that the cycles will continue until the profiling is finished. + on_trace_ready (`Callable`, *optional*, default to `None`): + Callable that is called at each step when schedule returns `ProfilerAction.RECORD_AND_SAVE` during the + profiling. + record_shapes (`bool`, *optional*, default to `False`): + Save information about operator’s input shapes. + profile_memory (`bool`, *optional*, default to `False`): + Track tensor memory allocation/deallocation + with_stack (`bool`, *optional*, default to `False`): + Record source information (file and line number) for the ops. + with_flops (`bool`, *optional*, default to `False`): + Use formula to estimate the FLOPS of specific operators + with_modules (`bool`, *optional*, default to `False`): + Record module hierarchy (including function names) corresponding to the callstack of the op. + output_trace_dir (`str`, *optional*, default to `None`): + Exports the collected trace in Chrome JSON format. Chrome use 'chrome://tracing' view json file. Defaults + to None, which means profiling does not store json files. + """ + + activities: Optional[list[ProfilerActivity]] = None + schedule_option: Optional[dict[str, int]] = None + on_trace_ready: Optional[Callable] = None + record_shapes: bool = False + profile_memory: bool = False + with_stack: bool = False + with_flops: bool = False + with_modules: bool = False + output_trace_dir: Optional[str] = None + + def _get_profiler_activity(self, activity: ProfilerActivity) -> torch.profiler.ProfilerActivity: + """Get the profiler activity from the string. + + Args: + activity (str): The profiler activity name. + + Returns: + torch.profiler.ProfilerActivity: The profiler activity. + """ + + profiler_activity_map: dict[str, torch.profiler.ProfilerActivity] = { + "cpu": torch.profiler.ProfilerActivity.CPU, + "cuda": torch.profiler.ProfilerActivity.CUDA, + } + + if is_hpu_available(): + profiler_activity_map["hpu"] = torch.profiler.ProfilerActivity.HPU + + if is_torch_version(">=", XPU_PROFILING_AVAILABLE_PYTORCH_VERSION): + if torch.xpu.is_available(): + profiler_activity_map["xpu"] = torch.profiler.ProfilerActivity.XPU + + if is_torch_version(">=", MITA_PROFILING_AVAILABLE_PYTORCH_VERSION): + if torch.mtia.is_available(): + profiler_activity_map["mtia"] = torch.profiler.ProfilerActivity.MTIA + + if activity not in profiler_activity_map: + raise ValueError(f"Invalid profiler activity: {activity}. Must be one of {list(profiler_activity_map)}.") + return profiler_activity_map[activity] + + def build(self) -> torch.profiler.profile: + """ + Build a profiler object with the current configuration. + + Returns: + torch.profiler.profile: The profiler object. + """ + activities: Optional[list[ProfilerActivity]] = None + if self.activities is not None: + activities = [self._get_profiler_activity(activity) for activity in self.activities] + schedule: Optional[torch.profiler.schedule] = None + if self.schedule_option is not None: + schedule = torch.profiler.schedule(**self.schedule_option) + + return torch.profiler.profile( + activities=activities, + schedule=schedule, + on_trace_ready=self.on_trace_ready, + record_shapes=self.record_shapes, + profile_memory=self.profile_memory, + with_stack=self.with_stack, + with_flops=self.with_flops, + with_modules=self.with_modules, + ) + + +class DistributedType(str, enum.Enum): + """ + Represents a type of distributed environment. + + Values: + + - **NO** -- Not a distributed environment, just a single process. + - **MULTI_CPU** -- Distributed on multiple CPU nodes. + - **MULTI_GPU** -- Distributed on multiple GPUs. + - **MULTI_MLU** -- Distributed on multiple MLUs. + - **MULTI_SDAA** -- Distributed on multiple SDAAs. + - **MULTI_MUSA** -- Distributed on multiple MUSAs. + - **MULTI_NPU** -- Distributed on multiple NPUs. + - **MULTI_XPU** -- Distributed on multiple XPUs. + - **MULTI_HPU** -- Distributed on multiple HPUs. + - **MULTI_NEURON** -- Distributed on multiple Neuron cores. + - **DEEPSPEED** -- Using DeepSpeed. + - **XLA** -- Using TorchXLA. + """ + + # Subclassing str as well as Enum allows the `DistributedType` to be JSON-serializable out of the box. + NO = "NO" + MULTI_CPU = "MULTI_CPU" + MULTI_GPU = "MULTI_GPU" + MULTI_NPU = "MULTI_NPU" + MULTI_MLU = "MULTI_MLU" + MULTI_SDAA = "MULTI_SDAA" + MULTI_MUSA = "MULTI_MUSA" + MULTI_XPU = "MULTI_XPU" + DEEPSPEED = "DEEPSPEED" + FSDP = "FSDP" + XLA = "XLA" + MEGATRON_LM = "MEGATRON_LM" + MULTI_HPU = "MULTI_HPU" + MULTI_NEURON = "MULTI_NEURON" + + +class SageMakerDistributedType(str, enum.Enum): + """ + Represents a type of distributed environment. + + Values: + + - **NO** -- Not a distributed environment, just a single process. + - **DATA_PARALLEL** -- using sagemaker distributed data parallelism. + - **MODEL_PARALLEL** -- using sagemaker distributed model parallelism. + """ + + # Subclassing str as well as Enum allows the `SageMakerDistributedType` to be JSON-serializable out of the box. + NO = "NO" + DATA_PARALLEL = "DATA_PARALLEL" + MODEL_PARALLEL = "MODEL_PARALLEL" + + +class FP8BackendType(str, enum.Enum): + """ + Represents the backend used for FP8. + + Values: + + - **TE** -- using TransformerEngine. + - **MSAMP** -- using msamp. + """ + + # Subclassing str as well as Enum allows the `FP8BackendType` to be JSON-serializable out of the box. + NO = "NO" + TE = "TE" + MSAMP = "MSAMP" + AO = "AO" + + +class ComputeEnvironment(str, enum.Enum): + """ + Represents a type of the compute environment. + + Values: + + - **LOCAL_MACHINE** -- private/custom cluster hardware. + - **AMAZON_SAGEMAKER** -- Amazon SageMaker as compute environment. + """ + + # Subclassing str as well as Enum allows the `ComputeEnvironment` to be JSON-serializable out of the box. + LOCAL_MACHINE = "LOCAL_MACHINE" + AMAZON_SAGEMAKER = "AMAZON_SAGEMAKER" + + +class DynamoBackend(str, BaseEnum): + """ + Represents a dynamo backend (see https://pytorch.org/docs/stable/torch.compiler.html). + + Values: + + - **NO** -- Do not use torch dynamo. + - **EAGER** -- Uses PyTorch to run the extracted GraphModule. This is quite useful in debugging TorchDynamo + issues. + - **AOT_EAGER** -- Uses AotAutograd with no compiler, i.e, just using PyTorch eager for the AotAutograd's + extracted forward and backward graphs. This is useful for debugging, and unlikely to give speedups. + - **INDUCTOR** -- Uses TorchInductor backend with AotAutograd and cudagraphs by leveraging codegened Triton + kernels. [Read + more](https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747) + - **AOT_TS_NVFUSER** -- nvFuser with AotAutograd/TorchScript. [Read + more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593) + - **NVPRIMS_NVFUSER** -- nvFuser with PrimTorch. [Read + more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593) + - **CUDAGRAPHS** -- cudagraphs with AotAutograd. [Read more](https://github.com/pytorch/torchdynamo/pull/757) + - **OFI** -- Uses Torchscript optimize_for_inference. Inference only. [Read + more](https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html) + - **FX2TRT** -- Uses Nvidia TensorRT for inference optimizations. Inference only. [Read + more](https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst) + - **ONNXRT** -- Uses ONNXRT for inference on CPU/GPU. Inference only. [Read more](https://onnxruntime.ai/) + - **TENSORRT** -- Uses ONNXRT to run TensorRT for inference optimizations. [Read + more](https://github.com/onnx/onnx-tensorrt) + - **AOT_TORCHXLA_TRACE_ONCE** -- Uses Pytorch/XLA with TorchDynamo optimization, for training. [Read + more](https://github.com/pytorch/xla/blob/r2.0/docs/dynamo.md) + - **TORCHXLA_TRACE_ONCE** -- Uses Pytorch/XLA with TorchDynamo optimization, for inference. [Read + more](https://github.com/pytorch/xla/blob/r2.0/docs/dynamo.md) + - **TVM** -- Uses Apache TVM for inference optimizations. [Read more](https://tvm.apache.org/) + - **HPU_BACKEND** -- Uses HPU backend for inference optimizations. + + """ + + # Subclassing str as well as Enum allows the `SageMakerDistributedType` to be JSON-serializable out of the box. + NO = "NO" + EAGER = "EAGER" + AOT_EAGER = "AOT_EAGER" + INDUCTOR = "INDUCTOR" + AOT_TS_NVFUSER = "AOT_TS_NVFUSER" + NVPRIMS_NVFUSER = "NVPRIMS_NVFUSER" + CUDAGRAPHS = "CUDAGRAPHS" + OFI = "OFI" + FX2TRT = "FX2TRT" + ONNXRT = "ONNXRT" + TENSORRT = "TENSORRT" + AOT_TORCHXLA_TRACE_ONCE = "AOT_TORCHXLA_TRACE_ONCE" + TORCHXLA_TRACE_ONCE = "TORCHXLA_TRACE_ONCE" + TVM = "TVM" + HPU_BACKEND = "HPU_BACKEND" + + +class LoggerType(BaseEnum): + """Represents a type of supported experiment tracker + + Values: + + - **ALL** -- all available trackers in the environment that are supported + - **TENSORBOARD** -- TensorBoard as an experiment tracker + - **WANDB** -- wandb as an experiment tracker + - **TRACKIO** -- trackio as an experiment tracker + - **COMETML** -- comet_ml as an experiment tracker + - **MLFLOW** -- mlflow as an experiment tracker + - **CLEARML** -- clearml as an experiment tracker + - **DVCLIVE** -- dvclive as an experiment tracker + - **SWANLAB** -- swanlab as an experiment tracker + """ + + ALL = "all" + AIM = "aim" + TENSORBOARD = "tensorboard" + WANDB = "wandb" + TRACKIO = "trackio" + COMETML = "comet_ml" + MLFLOW = "mlflow" + CLEARML = "clearml" + DVCLIVE = "dvclive" + SWANLAB = "swanlab" + + +class PrecisionType(str, BaseEnum): + """Represents a type of precision used on floating point values + + Values: + + - **NO** -- using full precision (FP32) + - **FP16** -- using half precision + - **BF16** -- using brain floating point precision + """ + + NO = "no" + FP8 = "fp8" + FP16 = "fp16" + BF16 = "bf16" + + +class RNGType(BaseEnum): + TORCH = "torch" + CUDA = "cuda" + MLU = "mlu" + SDAA = "sdaa" + MUSA = "musa" + NPU = "npu" + XLA = "xla" + XPU = "xpu" + HPU = "hpu" + NEURON = "neuron" + GENERATOR = "generator" + + +class CustomDtype(enum.Enum): + r""" + An enum that contains multiple custom dtypes that can be used for `infer_auto_device_map`. + """ + + FP8 = "fp8" + INT4 = "int4" + INT2 = "int2" + + +# data classes + + +@dataclass +class TensorInformation: + shape: torch.Size + dtype: torch.dtype + + +@dataclass +class DataLoaderConfiguration: + """ + Configuration for dataloader-related items when calling `accelerator.prepare`. + + Args: + split_batches (`bool`, defaults to `False`): + Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If + `True`, the actual batch size used will be the same on any kind of distributed processes, but it must be a + round multiple of `num_processes` you are using. If `False`, actual batch size used will be the one set in + your script multiplied by the number of processes. + dispatch_batches (`bool`, defaults to `None`): + If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process + and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose + underlying dataset is an `IterableDataset`, `False` otherwise. + even_batches (`bool`, defaults to `True`): + If set to `True`, in cases where the total batch size across all processes does not exactly divide the + dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among + all workers. + use_seedable_sampler (`bool`, defaults to `False`): + Whether or not use a fully seedable random sampler ([`data_loader.SeedableRandomSampler`]). Ensures + training results are fully reproducible using a different sampling technique. While seed-to-seed results + may differ, on average the differences are negligible when using multiple different seeds to compare. + Should also be ran with [`~utils.set_seed`] for the best results. + data_seed (`int`, defaults to `None`): + The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator + will use the current default seed from torch. + non_blocking (`bool`, defaults to `False`): + If set to `True`, the dataloader prepared by the Accelerator will utilize non-blocking host-to-device + transfers, allowing for better overlap between dataloader communication and computation. Recommended that + the prepared dataloader has `pin_memory` set to `True` to work properly. + use_stateful_dataloader (`bool`, defaults to `False`): + If set to `True`, the dataloader prepared by the Accelerator will be backed by + [torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). + This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed. + """ + + split_batches: bool = field( + default=False, + metadata={ + "help": "Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If" + " `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a" + " round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set" + " in your script multiplied by the number of processes." + }, + ) + dispatch_batches: bool = field( + default=None, + metadata={ + "help": "If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process" + " and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose" + " underlying dataset is an `IterableDataset`, `False` otherwise." + }, + ) + even_batches: bool = field( + default=True, + metadata={ + "help": "If set to `True`, in cases where the total batch size across all processes does not exactly divide the" + " dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among" + " all workers." + }, + ) + use_seedable_sampler: bool = field( + default=False, + metadata={ + "help": "Whether or not use a fully seedable random sampler ([`data_loader.SeedableRandomSampler`])." + "Ensures training results are fully reproducible using a different sampling technique. " + "While seed-to-seed results may differ, on average the differences are negligible when using" + "multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results." + }, + ) + data_seed: int = field( + default=None, + metadata={ + "help": "The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator" + " will use the current default seed from torch." + }, + ) + non_blocking: bool = field( + default=False, + metadata={ + "help": "If set to `True`, the dataloader prepared by the Accelerator will utilize non-blocking host-to-device" + " transfers, allowing for better overlap between dataloader communication and computation. Recommended that the" + " prepared dataloader has `pin_memory` set to `True` to work properly." + }, + ) + use_stateful_dataloader: bool = field( + default=False, + metadata={ + "help": "If set to `True`, the dataloader prepared by the Accelerator will be backed by " + "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed." + }, + ) + + +@dataclass +class ProjectConfiguration: + """ + Configuration for the Accelerator object based on inner-project needs. + + Args: + project_dir (`str`, defaults to `None`): + A path to a directory for storing data. + logging_dir (`str`, defaults to `None`): + A path to a directory for storing logs of locally-compatible loggers. If None, defaults to `project_dir`. + automatic_checkpoint_naming (`bool`, defaults to `False`): + Whether saved states should be automatically iteratively named. + total_limit (`int`, defaults to `None`): + The maximum number of total saved states to keep. + iteration (`int`, defaults to `0`): + The current save iteration. + save_on_each_node (`bool`, defaults to `False`): + When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on + the main one. + """ + + project_dir: str = field(default=None, metadata={"help": "A path to a directory for storing data."}) + logging_dir: str = field( + default=None, + metadata={ + "help": "A path to a directory for storing logs of locally-compatible loggers. If None, defaults to `project_dir`." + }, + ) + automatic_checkpoint_naming: bool = field( + default=False, + metadata={"help": "Whether saved states should be automatically iteratively named."}, + ) + + total_limit: int = field( + default=None, + metadata={"help": "The maximum number of total saved states to keep."}, + ) + + iteration: int = field( + default=0, + metadata={"help": "The current save iteration."}, + ) + + save_on_each_node: bool = field( + default=False, + metadata={ + "help": ( + "When doing multi-node distributed training, whether to save models and checkpoints on each node, or" + " only on the main one" + ) + }, + ) + + def set_directories(self, project_dir: Optional[str] = None): + "Sets `self.project_dir` and `self.logging_dir` to the appropriate values." + self.project_dir = project_dir + if self.logging_dir is None: + self.logging_dir = project_dir + + def __post_init__(self): + self.set_directories(self.project_dir) + + +@dataclass +class GradientAccumulationPlugin(KwargsHandler): + """ + A plugin to configure gradient accumulation behavior. You can only pass one of `gradient_accumulation_plugin` or + `gradient_accumulation_steps` to [`Accelerator`]. Passing both raises an error. + + Parameters: + num_steps (`int`): + The number of steps to accumulate gradients for. + adjust_scheduler (`bool`, *optional*, defaults to `True`): + Whether to adjust the scheduler steps to account for the number of steps being accumulated. Should be + `True` if the used scheduler was not adjusted for gradient accumulation. + sync_with_dataloader (`bool`, *optional*, defaults to `True`): + Whether to synchronize setting the gradients when at the end of the dataloader. + sync_each_batch (`bool`, *optional*): + Whether to synchronize setting the gradients at each data batch. Setting to `True` may reduce memory + requirements when using gradient accumulation with distributed training, at expense of speed. + + Example: + + ```python + from accelerate.utils import GradientAccumulationPlugin + + gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=2) + accelerator = Accelerator(gradient_accumulation_plugin=gradient_accumulation_plugin) + ``` + """ + + num_steps: int = field( + default=None, + metadata={"help": "The number of steps to accumulate gradients for."}, + ) + adjust_scheduler: bool = field( + default=True, + metadata={ + "help": "Whether to adjust the scheduler steps to account for the number of steps being accumulated. Should be `True` if the used scheduler was not adjusted for gradient accumulation." + }, + ) + sync_with_dataloader: bool = field( + default=True, + metadata={ + "help": "Whether to synchronize setting the gradients when at the end of the dataloader. Should only be set to `False` if you know what you're doing." + }, + ) + sync_each_batch: bool = field( + default=False, + metadata={ + "help": "Whether to synchronize setting the gradients at each data batch. Setting to `True` may reduce memory requirements when using gradient accumulation with distributed training, at expense of speed." + }, + ) + + +@dataclass +class TorchDynamoPlugin(KwargsHandler): + """ + This plugin is used to compile a model with PyTorch 2.0 + + Args: + backend (`DynamoBackend`, defaults to `None`): + A valid Dynamo backend. See https://pytorch.org/docs/stable/torch.compiler.html for more details. + mode (`str`, defaults to `None`): + Possible options are 'default', 'reduce-overhead' or 'max-autotune'. + fullgraph (`bool`, defaults to `None`): + Whether it is ok to break model into several subgraphs. + dynamic (`bool`, defaults to `None`): + Whether to use dynamic shape for tracing. + options (`Any`, defaults to `None`): + A dictionary of options to pass to the backend. + disable (`bool`, defaults to `False`): + Turn torch.compile() into a no-op for testing + use_regional_compilation (`bool`, defaults to `None`): + Use it to reduce the cold start compilation time of torch.compile() by targeting repeated blocks of the + same class and compiling them sequentially to hit the compiler's cache. For example, in `GPT2LMHeadModel`, + the repeated block/class is `GPT2Block`, and can be accessed as `model.transformer.h[0]`. The rest of the + model (e.g model.lm_head) is compiled separately. + """ + + backend: DynamoBackend = field( + default=None, + metadata={"help": f"Possible options are {[b.value.lower() for b in DynamoBackend]}"}, + ) + mode: str = field( + default=None, + metadata={"help": "Possible options are 'default', 'reduce-overhead' or 'max-autotune'"}, + ) + fullgraph: bool = field( + default=None, + metadata={"help": "Whether it is ok to break model into several subgraphs"}, + ) + dynamic: bool = field(default=None, metadata={"help": "Whether to use dynamic shape for tracing"}) + options: Any = field( + default=None, + metadata={"help": "A dictionary of options to pass to the backend."}, + ) + disable: bool = field( + default=False, + metadata={"help": "Turn torch.compile() into a no-op for testing"}, + ) + + use_regional_compilation: bool = field( + default=None, + metadata={ + "help": ( + # https://pytorch.org/tutorials/recipes/regional_compilation.html + "Use it to reduce the cold start compilation time of torch.compile() by targeting repeated " + "blocks of the same class and compiling them sequentially to hit the compiler's cache. For " + "example, in `GPT2LMHeadModel`, the repeated block/class is `GPT2Block`, and can be accessed " + "as `model.transformer.h[0]`. The rest of the model (e.g model.lm_head) is compiled separately." + ) + }, + ) + + def __post_init__(self): + prefix = "ACCELERATE_DYNAMO_" + if self.backend is None: + self.backend = os.environ.get(prefix + "BACKEND", "no") + self.backend = DynamoBackend(self.backend.upper()) + + if self.mode is None: + self.mode = os.environ.get(prefix + "MODE", "default") + if self.fullgraph is None: + self.fullgraph = str_to_bool(os.environ.get(prefix + "USE_FULLGRAPH", "False")) == 1 + if self.use_regional_compilation is None: + self.use_regional_compilation = ( + str_to_bool(os.environ.get(prefix + "USE_REGIONAL_COMPILATION", "False")) == 1 + ) + + if self.dynamic is None and os.environ.get(prefix + "USE_DYNAMIC", None) is not None: + self.dynamic = str_to_bool(os.environ.get(prefix + "USE_DYNAMIC", "False")) == 1 + + def to_dict(self): + dynamo_config = copy.deepcopy(self.__dict__) + dynamo_config["backend"] = dynamo_config["backend"].value.lower() + return dynamo_config + + def to_kwargs(self): + kwargs = super().to_kwargs() + kwargs.pop("use_regional_compilation", None) + return kwargs + + +@dataclass +class DeepSpeedPlugin: + """ + This plugin is used to integrate DeepSpeed. + + Args: + hf_ds_config (`Any`, defaults to `None`): + Path to DeepSpeed config file or dict or an object of class `accelerate.utils.deepspeed.HfDeepSpeedConfig`. + gradient_accumulation_steps (`int`, defaults to `None`): + Number of steps to accumulate gradients before updating optimizer states. If not set, will use the value + from the `Accelerator` directly. + gradient_clipping (`float`, defaults to `None`): + Enable gradient clipping with value. + zero_stage (`int`, defaults to `None`): + Possible options are 0, 1, 2, 3. Default will be taken from environment variable. + is_train_batch_min (`bool`, defaults to `True`): + If both train & eval dataloaders are specified, this will decide the `train_batch_size`. + offload_optimizer_device (`str`, defaults to `None`): + Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3. + offload_param_device (`str`, defaults to `None`): + Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3. + offload_optimizer_nvme_path (`str`, defaults to `None`): + Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3. + offload_param_nvme_path (`str`, defaults to `None`): + Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3. + zero3_init_flag (`bool`, defaults to `None`): + Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3. + zero3_save_16bit_model (`bool`, defaults to `None`): + Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3. + transformer_moe_cls_names (`str`, defaults to `None`): + Comma-separated list of Transformers MoE layer class names (case-sensitive). For example, + `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention`, `JetMoEBlock`, etc. + enable_msamp (`bool`, defaults to `None`): + Flag to indicate whether to enable MS-AMP backend for FP8 training. + msasmp_opt_level (`Optional[Literal["O1", "O2"]]`, defaults to `None`): + Optimization level for MS-AMP (defaults to 'O1'). Only applicable if `enable_msamp` is True. Should be one + of ['O1' or 'O2']. + """ + + hf_ds_config: Any = field( + default=None, + metadata={ + "help": "path to DeepSpeed config file or dict or an object of class `accelerate.utils.deepspeed.HfDeepSpeedConfig`." + }, + ) + gradient_accumulation_steps: int = field( + default=None, + metadata={ + "help": "Number of steps to accumulate gradients before updating optimizer states. If not set, will use the value from the `Accelerator` directly." + }, + ) + gradient_clipping: float = field(default=None, metadata={"help": "Enable gradient clipping with value"}) + zero_stage: int = field( + default=None, + metadata={"help": "Possible options are 0,1,2,3; Default will be taken from environment variable"}, + ) + is_train_batch_min: bool = field( + default=True, + metadata={"help": "If both train & eval dataloaders are specified, this will decide the train_batch_size"}, + ) + offload_optimizer_device: str = field( + default=None, + metadata={"help": "Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3."}, + ) + offload_param_device: str = field( + default=None, + metadata={"help": "Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3."}, + ) + offload_optimizer_nvme_path: str = field( + default=None, + metadata={"help": "Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."}, + ) + offload_param_nvme_path: str = field( + default=None, + metadata={"help": "Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."}, + ) + zero3_init_flag: bool = field( + default=None, + metadata={ + "help": "Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models." + "Only applicable with ZeRO Stage-3." + }, + ) + zero3_save_16bit_model: bool = field( + default=None, + metadata={"help": "Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3."}, + ) + transformer_moe_cls_names: str = field( + default=None, + metadata={ + "help": "comma-separated list of transformers MoE layer class names (case-sensitive), e.g : " + " `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock` ..." + }, + ) + enable_msamp: bool = field( + default=None, + metadata={"help": "Flag to indicate whether to enable MS-AMP backend for FP8 training."}, + ) + msamp_opt_level: Optional[Literal["O1", "O2"]] = field( + default=None, + metadata={ + "help": "Optimization level for MS-AMP (defaults to 'O1'). Only applicable if `enable_msamp` is True. Should be one of ['O1' or 'O2']." + }, + ) + + def __post_init__(self): + from .deepspeed import HfDeepSpeedConfig + + if self.gradient_accumulation_steps is None: + gas = os.environ.get("ACCELERATE_GRADIENT_ACCUMULATION_STEPS", "auto") + self.gradient_accumulation_steps = int(gas) if gas.isdigit() else gas + + if self.gradient_clipping is None: + gradient_clipping = os.environ.get("ACCELERATE_GRADIENT_CLIPPING", "auto") + self.gradient_clipping = gradient_clipping if gradient_clipping == "auto" else float(gradient_clipping) + + if self.zero_stage is None: + self.zero_stage = int(os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE", 2)) + + if self.offload_optimizer_device is None: + self.offload_optimizer_device = os.environ.get("ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE", "none") + + if self.offload_param_device is None: + self.offload_param_device = os.environ.get("ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_DEVICE", "none") + + if self.offload_optimizer_nvme_path is None: + self.offload_optimizer_nvme_path = os.environ.get( + "ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_NVME_PATH", "none" + ) + + if self.offload_param_nvme_path is None: + self.offload_param_nvme_path = os.environ.get("ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_NVME_PATH", "none") + + if self.zero3_save_16bit_model is None: + self.zero3_save_16bit_model = ( + os.environ.get("ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL", "false").lower() == "true" + ) + if self.enable_msamp is None: + self.enable_msamp = os.environ.get("ACCELERATE_FP8_BACKEND", None) == "MSAMP" + + if self.msamp_opt_level is None: + self.msamp_opt_level = os.environ.get("ACCELERATE_FP8_OPT_LEVEL", "O1") + + if self.hf_ds_config is None: + self.hf_ds_config = os.environ.get("ACCELERATE_DEEPSPEED_CONFIG_FILE", "none") + + if ( + isinstance(self.hf_ds_config, dict) + or (isinstance(self.hf_ds_config, str) and self.hf_ds_config != "none") + or isinstance(self.hf_ds_config, HfDeepSpeedConfig) + ): + if not isinstance(self.hf_ds_config, HfDeepSpeedConfig): + self.hf_ds_config = HfDeepSpeedConfig(self.hf_ds_config) + if "gradient_accumulation_steps" not in self.hf_ds_config.config: + self.hf_ds_config.config["gradient_accumulation_steps"] = 1 + if "zero_optimization" not in self.hf_ds_config.config: + raise ValueError("Please specify the ZeRO optimization config in the DeepSpeed config.") + + self._deepspeed_config_checks() + plugin_to_config_mapping = { + "gradient_accumulation_steps": "gradient_accumulation_steps", + "gradient_clipping": "gradient_clipping", + "zero_stage": "zero_optimization.stage", + "offload_optimizer_device": "zero_optimization.offload_optimizer.device", + "offload_param_device": "zero_optimization.offload_param.device", + "offload_param_nvme_path": "zero_optimization.offload_param.nvme_path", + "offload_optimizer_nvme_path": "zero_optimization.offload_optimizer.nvme_path", + "zero3_save_16bit_model": "zero_optimization.stage3_gather_16bit_weights_on_model_save", + } + kwargs = {v: getattr(self, k) for k, v in plugin_to_config_mapping.items() if getattr(self, k) is not None} + for key in kwargs.keys(): + self.fill_match(key, **kwargs, must_match=False) + self.hf_ds_config.set_stage_and_offload() + + # filling the missing values in the class attributes from the DeepSpeed config + # when using the DeepSpeed config file. + for key, value in plugin_to_config_mapping.items(): + config_value = self.hf_ds_config.get_value(value) + if config_value is not None and config_value != "auto": + setattr(self, key, config_value) + else: + config = { + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "gradient_accumulation_steps": self.gradient_accumulation_steps, + "zero_optimization": { + "stage": self.zero_stage, + "offload_optimizer": { + "device": self.offload_optimizer_device, + "nvme_path": ( + self.offload_optimizer_nvme_path if self.offload_optimizer_device == "nvme" else None + ), + }, + "offload_param": { + "device": self.offload_param_device, + "nvme_path": (self.offload_param_nvme_path if self.offload_param_device == "nvme" else None), + }, + "stage3_gather_16bit_weights_on_model_save": self.zero3_save_16bit_model, + }, + } + if self.gradient_clipping: + config["gradient_clipping"] = self.gradient_clipping + self.hf_ds_config = HfDeepSpeedConfig(config) + + self.deepspeed_config = self.hf_ds_config.config + self.deepspeed_config["steps_per_print"] = float("inf") # this will stop deepspeed from logging @ stdout + if self.zero3_init_flag is None: + self.zero3_init_flag = ( + str_to_bool( + os.environ.get( + "ACCELERATE_DEEPSPEED_ZERO3_INIT", + str(self.hf_ds_config.is_zero3()), + ) + ) + == 1 + ) + if self.zero3_init_flag and not self.hf_ds_config.is_zero3(): + warnings.warn("DeepSpeed Zero3 Init flag is only applicable for ZeRO Stage 3. Setting it to False.") + self.zero3_init_flag = False + # NOTE: Set to False by default, will be set to `True` automatically if it's the first plugin passed + # to the `Accelerator`'s `deepspeed_plugin` param, *or* `AcceleratorState().enable_deepspeed_plugin(plugin_key)` is manually called + self._set_selected(False) + + # Ignore if it's already set + if self.enable_msamp and "msamp" not in self.deepspeed_config: + if self.zero_stage == 3: + raise NotImplementedError( + "MS-AMP is not supported for ZeRO Stage 3. Please use ZeRO Stage 0, 1, or 2 instead." + ) + if self.msamp_opt_level not in ["O1", "O2"]: + raise ValueError("Invalid optimization level for MS-AMP. Please use one of ['O1' or'O2'].") + self.deepspeed_config["msamp"] = { + "enabled": True, + "opt_level": self.msamp_opt_level, + } + + def fill_match(self, ds_key_long, mismatches=None, must_match=True, **kwargs): + mismatches = [] if mismatches is None else mismatches + config, ds_key = self.hf_ds_config.find_config_node(ds_key_long) + if config is None: + return + + if config.get(ds_key) == "auto": + if ds_key_long in kwargs: + config[ds_key] = kwargs[ds_key_long] + return + else: + raise ValueError( + f"`{ds_key_long}` not found in kwargs. " + f"Please specify `{ds_key_long}` without `auto` (set to correct value) in the DeepSpeed config file or " + "pass it in kwargs." + ) + + if not must_match: + return + + ds_val = config.get(ds_key) + if ds_val is not None and ds_key_long in kwargs: + if ds_val != kwargs[ds_key_long]: + mismatches.append(f"- ds {ds_key_long}={ds_val} vs arg {ds_key_long}={kwargs[ds_key_long]}") + + def is_auto(self, ds_key_long): + val = self.hf_ds_config.get_value(ds_key_long) + if val is None: + return False + else: + return val == "auto" + + def get_value(self, ds_key_long, default=None): + return self.hf_ds_config.get_value(ds_key_long, default) + + def deepspeed_config_process(self, prefix="", mismatches=None, config=None, must_match=True, **kwargs): + """Process the DeepSpeed config with the values from the kwargs.""" + mismatches = [] if mismatches is None else mismatches + if config is None: + config = self.deepspeed_config + for key, value in config.items(): + if isinstance(value, dict): + self.deepspeed_config_process( + prefix=prefix + key + ".", + mismatches=mismatches, + config=value, + must_match=must_match, + **kwargs, + ) + else: + self.fill_match(prefix + key, mismatches, must_match=must_match, **kwargs) + if len(mismatches) > 0 and prefix == "": + mismatches_msg = "\n".join(mismatches) + raise ValueError( + "Please correct the following DeepSpeed config values that mismatch kwargs " + f" values:\n{mismatches_msg}\nThe easiest method is to set these DeepSpeed config values to 'auto'." + ) + + def set_mixed_precision(self, mixed_precision): + ds_config = self.deepspeed_config + kwargs = { + "fp16.enabled": mixed_precision == "fp16", + # When training in fp8, we still rely on bf16 autocast for the core mixed precision + "bf16.enabled": mixed_precision in ("bf16", "fp8"), + } + if mixed_precision == "fp16": + if "fp16" not in ds_config: + ds_config["fp16"] = {"enabled": True, "auto_cast": True} + elif mixed_precision in ("bf16", "fp8"): + if "bf16" not in ds_config: + ds_config["bf16"] = {"enabled": True} + + if mixed_precision == "fp8" and self.enable_msamp: + if "msamp" not in ds_config: + ds_config["msamp"] = { + "enabled": True, + "opt_level": self.msamp_opt_level, + } + + if mixed_precision != "no": + diff_dtype = "bf16" if mixed_precision == "fp16" else "fp16" + if str(ds_config.get(diff_dtype, {}).get("enabled", "False")).lower() == "true": + raise ValueError( + f"`--mixed_precision` arg cannot be set to `{mixed_precision}` when `{diff_dtype}` is set in the DeepSpeed config file." + ) + for dtype in ["fp16", "bf16"]: + if dtype not in ds_config: + ds_config[dtype] = {"enabled": False} + self.fill_match("fp16.enabled", must_match=False, **kwargs) + self.fill_match("bf16.enabled", must_match=False, **kwargs) + + def set_deepspeed_weakref(self): + from .imports import is_transformers_available + + ds_config = copy.deepcopy(self.deepspeed_config) + if self.zero3_init_flag: + if not is_transformers_available(): + raise Exception( + "When `zero3_init_flag` is set, it requires Transformers to be installed. " + "Please run `pip install transformers`." + ) + if "gradient_accumulation_steps" not in ds_config or ds_config["gradient_accumulation_steps"] == "auto": + ds_config["gradient_accumulation_steps"] = 1 + if "train_micro_batch_size_per_gpu" not in ds_config or ds_config["train_micro_batch_size_per_gpu"] == "auto": + ds_config["train_micro_batch_size_per_gpu"] = 1 + if ds_config.get("train_batch_size", None) == "auto": + del ds_config["train_batch_size"] + + if compare_versions("transformers", "<", "4.46"): + from transformers.deepspeed import ( + HfDeepSpeedConfig, + unset_hf_deepspeed_config, + ) + else: + from transformers.integrations import ( + HfDeepSpeedConfig, + unset_hf_deepspeed_config, + ) + + unset_hf_deepspeed_config() + self.dschf = HfDeepSpeedConfig(ds_config) # keep this object alive # noqa + + def is_zero3_init_enabled(self): + return self.zero3_init_flag + + @contextmanager + def zero3_init_context_manager(self, enable=False): + old = self.zero3_init_flag + if old == enable: + yield + else: + self.zero3_init_flag = enable + self.dschf = None + self.set_deepspeed_weakref() + yield + self.zero3_init_flag = old + self.dschf = None + self.set_deepspeed_weakref() + + def _deepspeed_config_checks(self): + env_variable_names_to_ignore = [ + "ACCELERATE_GRADIENT_ACCUMULATION_STEPS", + "ACCELERATE_GRADIENT_CLIPPING", + "ACCELERATE_DEEPSPEED_ZERO_STAGE", + "ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE", + "ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_DEVICE", + "ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_NVME_PATH", + "ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_NVME_PATH", + "ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL", + "ACCELERATE_MIXED_PRECISION", + ] + env_variable_names_to_ignore = [ + name.replace("ACCELERATE_", "").replace("DEEPSPEED_", "").lower() for name in env_variable_names_to_ignore + ] + + deepspeed_fields_from_accelerate_config = os.environ.get("ACCELERATE_CONFIG_DS_FIELDS", "").split(",") + + if any(name in env_variable_names_to_ignore for name in deepspeed_fields_from_accelerate_config): + raise ValueError( + f"When using `deepspeed_config_file`, the following accelerate config variables will be ignored: {env_variable_names_to_ignore}.\n" + "Please specify them appropriately in the DeepSpeed config file.\n" + "If you are using an accelerate config file, remove others config variables mentioned in the above specified list.\n" + "The easiest method is to create a new config following the questionnaire via `accelerate config`.\n" + "It will only ask for the necessary config variables when using `deepspeed_config_file`." + ) + + def set_moe_leaf_modules(self, model): + if self.transformer_moe_cls_names is None: + self.transformer_moe_cls_names = os.environ.get("ACCELERATE_DEEPSPEED_MOE_LAYER_CLS_NAMES", None) + if self.transformer_moe_cls_names is not None: + if compare_versions("deepspeed", "<", "0.14.0"): + raise ImportError("DeepSpeed version must be >= 0.14.0 to use MOE support. Please update DeepSpeed.") + from deepspeed.utils import set_z3_leaf_modules + + class_names = self.transformer_moe_cls_names.split(",") + transformer_moe_cls = [] + for layer_class in class_names: + transformer_cls = get_module_class_from_name(model, layer_class) + if transformer_cls is None: + raise Exception( + f"Could not find a transformer layer class called '{layer_class}' to wrap in the model." + ) + else: + transformer_moe_cls.append(transformer_cls) + set_z3_leaf_modules(model, transformer_moe_cls) # z3_leaf + + def select(self, _from_accelerator_state: bool = False): + """ + Sets the HfDeepSpeedWeakref to use the current deepspeed plugin configuration + """ + if not _from_accelerator_state: + raise ValueError( + "A `DeepSpeedPlugin` object must be enabled manually by calling `AcceleratorState().enable_deepspeed_plugin(plugin_key)`." + ) + self.set_deepspeed_weakref() + self._set_selected(True) + + def _unselect(self): + self._set_selected(False) + + def _set_selected(self, value: bool): + """ + Private setter for the 'enabled' attribute. + """ + self._selected = value + + @property + def selected(self): + return self._selected + + @selected.setter + def selected(self, value): + raise NotImplementedError( + "'enabled' can only be set through calling 'AcceleratorState().enable_deepspeed_plugin(key)'." + ) + + +@dataclass +class FullyShardedDataParallelPlugin: + """ + This plugin is used to enable fully sharded data parallelism. + + Args: + fsdp_version (`int`, defaults to `1`): + The version of FSDP to use. Defaults to 1. If set to 2, launcher expects the config to be converted to + FSDP2 format. + sharding_strategy (`Union[str, torch.distributed.fsdp.ShardingStrategy]`, defaults to `'FULL_SHARD'`): + Sharding strategy to use. Should be either a `str` or an instance of + `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. Is deprecated in favor of + `reshard_after_forward`. + reshard_after_forward (`Union[str, torch.distributed.fsdp.ShardingStrategy, bool]`, defaults to `'FULL_SHARD'` for `fsdp_version=1` and `True` for `fsdp_version=2`): + Sharding strategy to use. Should be a bool if `fsdp_version` is set to 2 else a `str` or an instance of + `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. + backward_prefetch (`Union[str, torch.distributed.fsdp.BackwardPrefetch]`, defaults to `'NO_PREFETCH'`): + Backward prefetch strategy to use. Should be either a `str` or an instance of + `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`. + mixed_precision_policy (`Optional[Union[dict, str, torch.distributed.fsdp.MixedPrecision, torch.distributed.fsdp.MixedPrecisionPolicy]]`, defaults to `None`): + A config to enable mixed precision training with FullyShardedDataParallel. If passing in a `dict`, it + should have the following keys: `param_dtype`, `reduce_dtype`, and `buffer_dtype`, can be an instance of + `torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2. If passing in a `str`, it + should be one of the following values: fp8, fp16, bf16, fp32, and used to set `param_dtype`, + `reduce_dtype`, and `buffer_dtype`. + auto_wrap_policy (`Optional(Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]), defaults to `NO_WRAP`): + A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one + of `transformer_based_wrap`, `size_based_wrap`, or `no_wrap`. See + `torch.distributed.fsdp.wrap.size_based_wrap_policy` for a direction on what it should look like. + cpu_offload (`Union[bool, torch.distributed.fsdp.CPUOffload, torch.distributed.fsdp.CPUOffloadPolicy]`, defaults to `False`): + Whether to offload parameters to CPU. Should be either a `bool` or an instance of + `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or + `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2. + ignored_modules (`Optional[Union[Iterable[torch.nn.Module], str]]`, defaults to `None`): + A list of modules to ignore when wrapping with FSDP. When passing a string, will match the modules by name + using regex fullmatch. If `fsdp_version` is set to 2, the modules are converted to parameters and used. + state_dict_type (`Union[str, torch.distributed.fsdp.StateDictType]`, defaults to `'FULL_STATE_DICT'`): + State dict type to use. If a string, it must be one of `full_state_dict`, `local_state_dict`, or + `sharded_state_dict`. + state_dict_config (`Optional[Union[torch.distributed.fsdp.FullStateDictConfig, torch.distributed.fsdp.ShardedStateDictConfig]`, defaults to `None`): + State dict config to use. Is determined based on the `state_dict_type` if not passed in. + optim_state_dict_config (`Optional[Union[torch.distributed.fsdp.FullOptimStateDictConfig, torch.distributed.fsdp.ShardedOptimStateDictConfig]`, defaults to `None`): + Optim state dict config to use. Is determined based on the `state_dict_type` if not passed in. + limit_all_gathers (`bool`, defaults to `True`): + Whether to have FSDP explicitly synchronizes the CPU thread to prevent too many in-flight all-gathers. This + bool only affects the sharded strategies that schedule all-gathers. Enabling this can help lower the number + of CUDA malloc retries. + use_orig_params (`bool`, defaults to `False`): + Whether to use the original parameters for the optimizer. + param_init_fn (`Optional[Callable[[torch.nn.Module], None]`, defaults to `None`): + A `Callable[torch.nn.Module] -> None` that specifies how modules that are currently on the meta device + should be initialized onto an actual device. Only applicable when `sync_module_states` is `True`. By + default is a `lambda` which calls `to_empty` on the module. + sync_module_states (`bool`, defaults to `False`): + Whether each individually wrapped FSDP unit should broadcast module parameters from rank 0 to ensure they + are the same across all ranks after initialization. Defaults to `False` unless `cpu_ram_efficient_loading` + is `True`, then will be forcibly enabled. + forward_prefetch (`bool`, defaults to `False`): + Whether to have FSDP explicitly prefetches the next upcoming all-gather while executing in the forward + pass. only use with Static graphs. + activation_checkpointing (`bool`, defaults to `False`): + A technique to reduce memory usage by clearing activations of certain layers and recomputing them during a + backward pass. Effectively, this trades extra computation time for reduced memory usage. + cpu_ram_efficient_loading (`bool`, defaults to `None`): + If True, only the first process loads the pretrained model checkoint while all other processes have empty + weights. Only applicable for Transformers. When using this, `sync_module_states` needs to be `True`. + transformer_cls_names_to_wrap (`Optional[List[str]]`, defaults to `None`): + A list of transformer layer class names to wrap. Only applicable when `auto_wrap_policy` is + `transformer_based_wrap`. + min_num_params (`Optional[int]`, defaults to `None`): + The minimum number of parameters a module must have to be wrapped. Only applicable when `auto_wrap_policy` + is `size_based_wrap`. + """ + + fsdp_version: int = field( + default=None, + metadata={ + "help": "The version of FSDP to use. Defaults to 1. If set to 2, launcher expects the config to be converted to FSDP2 format." + }, + ) + + sharding_strategy: Union[str, "torch.distributed.fsdp.ShardingStrategy"] = field( + default=None, + metadata={ + "help": "Sharding strategy to use. Should be either a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. Defaults to 'FULL_SHARD'. Is deprecated in favor of `reshard_after_forward` " + }, + ) + + reshard_after_forward: Union[str, "torch.distributed.fsdp.ShardingStrategy", bool] = field( + default=None, + metadata={ + "help": "Sharding strategy to use. Should be a bool if `fsdp_version` is set to 2 else a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. Defaults to 'FULL_SHARD'" + }, + ) + backward_prefetch: Optional[Union[str, "torch.distributed.fsdp.BackwardPrefetch"]] = field( + default=None, + metadata={ + "help": "Backward prefetch strategy to use. Should be either a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`. Defaults to 'NO_PREFETCH'. This becomes obsolete in FSDP2." + }, + ) + mixed_precision_policy: Optional[ + Union[ + dict, + str, + "torch.distributed.fsdp.MixedPrecision", + "torch.distributed.fsdp.MixedPrecisionPolicy", + ] + ] = field( + default=None, + metadata={ + "help": "A config to enable mixed precision training with FullyShardedDataParallel. " + "If passing in a `dict`, it should have the following keys: `param_dtype`, `reduce_dtype`, and `buffer_dtype`." + "Can also be an instance of `torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2." + }, + ) + auto_wrap_policy: Optional[Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]] = ( + field( + default=None, + metadata={ + "help": "A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one of `transformer_based_wrap`, `size_based_wrap`, or `no_wrap`. " + "Defaults to `NO_WRAP`. See `torch.distributed.fsdp.wrap.size_based_wrap_policy` for a direction on what it should look like" + }, + ) + ) + cpu_offload: Union[ + bool, + "torch.distributed.fsdp.CPUOffload", + "torch.distributed.fsdp.CPUOffloadPolicy", + ] = field( + default=None, + metadata={ + "help": "Whether to offload parameters to CPU. Should be either a `bool` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2. Defaults to `False`" + }, + ) + ignored_modules: Optional[Union[Iterable[torch.nn.Module], str]] = field( + default=None, + metadata={"help": "A list of modules to ignore when wrapping with FSDP."}, + ) + + state_dict_type: Union[str, "torch.distributed.fsdp.StateDictType"] = field( + default=None, + metadata={ + "help": "State dict type to use. If a string, it must be one of `full_state_dict`, `local_state_dict`, or `sharded_state_dict`. Defaults to `FULL_STATE_DICT`" + }, + ) + state_dict_config: Optional[ + Union[ + "torch.distributed.fsdp.FullStateDictConfig", + "torch.distributed.fsdp.ShardedStateDictConfig", + ] + ] = field( + default=None, + metadata={"help": "State dict config to use. Is determined based on the `state_dict_type` if not passed in."}, + ) + optim_state_dict_config: Optional[ + Union[ + "torch.distributed.fsdp.FullOptimStateDictConfig", + "torch.distributed.fsdp.ShardedOptimStateDictConfig", + ] + ] = field( + default=None, + metadata={ + "help": "Optim state dict config to use. Is determined based on the `state_dict_type` if not passed in." + }, + ) + limit_all_gathers: bool = field( + default=True, + metadata={ + "help": "Whether to have FSDP explicitly synchronizes the CPU thread to prevent " + "too many in-flight all-gathers. This bool only affects the sharded strategies that schedule all-gathers. " + "Enabling this can help lower the number of CUDA malloc retries." + }, + ) + use_orig_params: Optional[bool] = field( + default=None, + metadata={ + "help": "Whether to use the original parameters for the optimizer. Defaults to `False`. This becomes obsolete in FSDP2." + }, + ) + param_init_fn: Optional[Callable[[torch.nn.Module], None]] = field( + default=None, + metadata={ + "help": "A Callable[torch.nn.Module] -> None that specifies how modules " + "that are currently on the meta device should be initialized onto an actual device. " + "Only applicable when `sync_module_states` is `True`. By default is a `lambda` which calls `to_empty` on the module." + }, + ) + sync_module_states: Optional[bool] = field( + default=None, + metadata={ + "help": "Whether each individually wrapped FSDP unit should broadcast module parameters from rank 0 " + "to ensure they are the same across all ranks after initialization. Defaults to `False` unless " + "`cpu_ram_efficient_loading` is `True`, then will be forcibly enabled. This becomes obsolete in FSDP2." + }, + ) + forward_prefetch: bool = field( + default=None, + metadata={ + "help": "Whether to have FSDP explicitly prefetches the next upcoming " + "all-gather while executing in the forward pass. only use with Static graphs. Defaults to `False`" + }, + ) + activation_checkpointing: bool = field( + default=None, + metadata={ + "help": "A technique to reduce memory usage by clearing activations of " + "certain layers and recomputing them during a backward pass. Effectively, this trades extra computation time " + "for reduced memory usage. Defaults to `False`" + }, + ) + cpu_ram_efficient_loading: bool = field( + default=None, + metadata={ + "help": "If True, only the first process loads the pretrained model checkoint while all other processes have empty weights. " + "Only applicable for 🤗 Transformers. When using this, `sync_module_states` needs to be `True`. Defaults to `False`." + }, + ) + transformer_cls_names_to_wrap: Optional[list[str]] = field( + default=None, + metadata={ + "help": "A list of transformer layer class names to wrap. Only applicable when `auto_wrap_policy` is `transformer_based_wrap`." + }, + ) + min_num_params: Optional[int] = field( + default=None, + metadata={ + "help": "The minimum number of parameters a module must have to be wrapped. Only applicable when `auto_wrap_policy` is `size_based_wrap`." + }, + ) + + def __post_init__(self): + from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy + + _fsdp2_warnings = set() + + env_prefix = "FSDP_" + # Strategy: By default we should always assume that values are passed in, else we check the environment variables + if self.fsdp_version is None: + self.fsdp_version = int(os.environ.get(env_prefix + "VERSION", "1")) + + if self.fsdp_version == 2: + if not is_torch_version(">=", FSDP2_PYTORCH_VERSION): + raise ImportError(f"FSDP2 requires PyTorch >= {FSDP2_PYTORCH_VERSION}") + + if self.sharding_strategy is not None: + # We cannot properly detect all of the cases, as by default `args.fsdp_sharding_strategy` is set to `fully_shard` + # Therefore we issue a warning only if the user has explicitly set it inside their plugin + _fsdp2_warnings.add( + "sharding_strategy is deprecated in favor of reshard_after_forward. " + "This will be removed in a future version of Accelerate." + ) + if self.fsdp_version == 1: + if self.sharding_strategy is None: + self.sharding_strategy = os.environ.get(env_prefix + "SHARDING_STRATEGY", "FULL_SHARD") + if isinstance(self.sharding_strategy, str): + if self.sharding_strategy.upper() in FSDP_SHARDING_STRATEGY: + self.sharding_strategy = FSDP_SHARDING_STRATEGY.index(self.sharding_strategy.upper()) + 1 + if isinstance(self.sharding_strategy, int) or self.sharding_strategy.isdigit(): + self.sharding_strategy = ShardingStrategy(int(self.sharding_strategy)) + else: + self.sharding_strategy = ShardingStrategy[self.sharding_strategy.upper()] + + # Fallback to `reshard_after_forward` in FSDP1 if `sharding_strategy` is not set + if self.reshard_after_forward is None and self.sharding_strategy is None: + reshard_after_forward = os.environ.get( + env_prefix + "RESHARD_AFTER_FORWARD", + "true" if self.fsdp_version == 2 else "FULL_SHARD", + ) + if self.fsdp_version == 2: + self.reshard_after_forward = str_to_bool(reshard_after_forward.lower(), to_bool=True) + else: + self.reshard_after_forward = reshard_after_forward + if isinstance(self.reshard_after_forward, str): + if self.fsdp_version == 2: + self.reshard_after_forward = str_to_bool(self.reshard_after_forward.lower(), to_bool=True) + else: + # We need to remap based on custom enum values for user readability + if self.reshard_after_forward.upper() in FSDP_SHARDING_STRATEGY: + self.reshard_after_forward = FSDP_SHARDING_STRATEGY.index(self.reshard_after_forward.upper()) + 1 + if isinstance(self.reshard_after_forward, int) or self.reshard_after_forward.isdigit(): + self.reshard_after_forward = ShardingStrategy(int(self.reshard_after_forward)) + else: + self.reshard_after_forward = ShardingStrategy[self.reshard_after_forward.upper()] + + if self.fsdp_version == 2 and not isinstance(self.reshard_after_forward, bool): + raise ValueError( + f"reshard_after_forward set to {self.reshard_after_forward}. This is not supported with FSDP2, please set to a `bool`" + ) + if self.fsdp_version == 1 and isinstance(self.reshard_after_forward, bool): + raise ValueError( + f"reshard_after_forward set to {self.reshard_after_forward}. This is not supported with FSDP1, please set to a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`" + ) + + if self.cpu_offload is None: + self.cpu_offload = str_to_bool(os.environ.get(env_prefix + "OFFLOAD_PARAMS", "False")) == 1 + + self.set_cpu_offload() # abstracted away to hide imports due to version checks + self.validate_cpu_offload() + + if self.backward_prefetch is None: + self.backward_prefetch = os.environ.get(env_prefix + "BACKWARD_PREFETCH", None) + if isinstance(self.backward_prefetch, str) and self.backward_prefetch.upper() == "NO_PREFETCH": + self.backward_prefetch = None + if self.backward_prefetch is not None and not isinstance(self.backward_prefetch, BackwardPrefetch): + if isinstance(self.backward_prefetch, str) and self.backward_prefetch.upper() in FSDP_BACKWARD_PREFETCH: + self.backward_prefetch = FSDP_BACKWARD_PREFETCH.index(self.backward_prefetch.upper()) + 1 + if isinstance(self.backward_prefetch, int) or self.backward_prefetch.isdigit(): + self.backward_prefetch = BackwardPrefetch(int(self.backward_prefetch)) + else: + self.backward_prefetch = BackwardPrefetch[self.backward_prefetch.upper()] + if self.fsdp_version == 2 and self.backward_prefetch is not None: + _fsdp2_warnings.add("backward_prefetch is not supported in FSDP2. Setting backward prefetch to None.") + self.backward_prefetch = None + + self.set_state_dict_type() + + if self.auto_wrap_policy is None: + self.auto_wrap_policy = os.environ.get(env_prefix + "AUTO_WRAP_POLICY", "NO_WRAP") + if isinstance(self.auto_wrap_policy, str): + if self.auto_wrap_policy.upper() not in FSDP_AUTO_WRAP_POLICY: + raise ValueError( + f"Invalid auto wrap policy: {self.auto_wrap_policy}. Must be one of {FSDP_AUTO_WRAP_POLICY}" + ) + from torch.distributed.fsdp.wrap import ( + size_based_auto_wrap_policy, + transformer_auto_wrap_policy, + ) + + if self.auto_wrap_policy.upper() == "TRANSFORMER_BASED_WRAP": + self.auto_wrap_policy = transformer_auto_wrap_policy + if self.transformer_cls_names_to_wrap is None: + self.transformer_cls_names_to_wrap = os.environ.get(env_prefix + "TRANSFORMER_CLS_TO_WRAP", None) + if isinstance(self.transformer_cls_names_to_wrap, str): + self.transformer_cls_names_to_wrap = self.transformer_cls_names_to_wrap.split(",") + elif self.auto_wrap_policy.upper() == "SIZE_BASED_WRAP": + self.auto_wrap_policy = size_based_auto_wrap_policy + if self.min_num_params is None: + self.min_num_params = int(os.environ.get(env_prefix + "MIN_NUM_PARAMS", 0)) + elif not isinstance(self.min_num_params, int): + raise ValueError( + f"`min_num_params` must be an integer. Got {self.min_num_params} of type {type(self.min_num_params)}" + ) + elif self.auto_wrap_policy.upper() == "NO_WRAP": + self.auto_wrap_policy = None + + if self.use_orig_params is None and self.fsdp_version == 1: + self.use_orig_params = str_to_bool(os.environ.get(env_prefix + "USE_ORIG_PARAMS", "False")) == 1 + if self.fsdp_version == 2 and self.use_orig_params is not None: + _fsdp2_warnings.add("use_orig_params is obsolete in FSDP2, as FSDP2 always uses the original parameters.") + self.use_orig_params = None + + if self.sync_module_states is None and self.fsdp_version == 1: + self.sync_module_states = str_to_bool(os.environ.get(env_prefix + "SYNC_MODULE_STATES", "False")) == 1 + if self.fsdp_version == 2 and self.sync_module_states is not None: + _fsdp2_warnings.add( + "sync_module_states is obsolete in FSDP2, as it is not needed anymore." + "Setting sync_module_states to None." + ) + self.sync_module_states = None + + if self.forward_prefetch is None and self.fsdp_version == 1: + self.forward_prefetch = str_to_bool(os.environ.get(env_prefix + "FORWARD_PREFETCH", "False")) == 1 + if self.fsdp_version == 2 and self.forward_prefetch is not None: + raise ValueError("forward_prefetch is not yet implemented in FSDP2, set to None or use `fsdp_version=1`") + + if self.activation_checkpointing is None: + self.activation_checkpointing = ( + str_to_bool(os.environ.get(env_prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1 + ) + + if self.ignored_modules is None: + self.ignored_modules = os.environ.get(env_prefix + "IGNORED_MODULES", None) + + if self.cpu_ram_efficient_loading is None: + self.cpu_ram_efficient_loading = ( + str_to_bool(os.environ.get(env_prefix + "CPU_RAM_EFFICIENT_LOADING", "False")) == 1 + ) + else: + # We still need to set it for transformers + os.environ[env_prefix + "CPU_RAM_EFFICIENT_LOADING"] = str(self.cpu_ram_efficient_loading) + # There's no need to specify sync_module_states in FSDP2 + if self.fsdp_version == 1 and self.cpu_ram_efficient_loading and not self.sync_module_states: + warnings.warn( + "sync_module_states cannot be False since efficient cpu ram loading enabled. " + "Setting sync_module_states to True." + ) + self.sync_module_states = True + if isinstance(self.mixed_precision_policy, str): + # override is True since self.mixed_precision_policy is not None + # has to be overwritten with the correct mixed precision object + self.set_mixed_precision(self.mixed_precision_policy, override=True) + elif isinstance(self.mixed_precision_policy, dict): + self.set_mixed_precision(self.mixed_precision_policy) + if self.mixed_precision_policy is not None: + self.validate_mixed_precision_policy() + + if self.sync_module_states: + if is_npu_available(): + device = torch.npu.current_device() + elif is_mlu_available(): + device = torch.mlu.current_device() + elif is_musa_available(): + device = torch.musa.current_device() + elif is_cuda_available(): + device = torch.cuda.current_device() + elif is_xpu_available(): + device = torch.xpu.current_device() + elif is_hpu_available(): + device = torch.hpu.current_device() + else: + raise RuntimeError( + "There are currently no available devices found, must be one of 'XPU', 'CUDA', 'MLU', 'NPU', 'MUSA', or 'HPU'." + ) + # Create a function that will be used to initialize the parameters of the model + # when using `sync_module_states` + self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False) + if is_torch_version("<", "2.7.0") and self.fsdp_version == 2 and self.ignored_modules is not None: + _fsdp2_warnings.add( + "FSDP2 ignored_params/ignored_modules is not available for torch version < 2.7.0" + "Setting ignored_modules to None." + ) + self.ignored_modules = None + # Single warning for all deprecation warnings due to FSDP2 conversion + if _fsdp2_warnings: + logger.warning("Multiple deprecation warnings due to FSDP2 conversion:\n".join(_fsdp2_warnings)) + + def set_state_dict_type(self, state_dict_type=None): + """ + Set the state dict config based on the `StateDictType`. + """ + from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullOptimStateDictConfig, + FullStateDictConfig, + ShardedOptimStateDictConfig, + ShardedStateDictConfig, + StateDictType, + ) + + # Override the state_dict_type if provided, typical use case: + # user trains with sharded, but final save is with full + if state_dict_type is not None: + self.state_dict_type = state_dict_type + + if self.state_dict_type is None: + self.state_dict_type = os.environ.get( + "FSDP_STATE_DICT_TYPE", + "FULL_STATE_DICT" if self.fsdp_version == 1 else "SHARDED_STATE_DICT", + ) + if isinstance(self.state_dict_type, str): + if self.state_dict_type.isdigit(): + self.state_dict_type = StateDictType(int(self.state_dict_type)) + else: + self.state_dict_type = StateDictType[self.state_dict_type.upper()] + + if self.state_dict_type == StateDictType.FULL_STATE_DICT: + if self.state_dict_config is None: + self.state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + if self.optim_state_dict_config is None: + self.optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True) + elif self.state_dict_type == StateDictType.SHARDED_STATE_DICT: + if self.state_dict_config is None: + self.state_dict_config = ShardedStateDictConfig(offload_to_cpu=True) + if self.optim_state_dict_config is None: + self.optim_state_dict_config = ShardedOptimStateDictConfig(offload_to_cpu=True) + + if self.fsdp_version == 2 and self.state_dict_type == StateDictType.LOCAL_STATE_DICT: + raise ValueError( + "FSDP2 does not support LOCAL_STATE_DICT. " + "Please set `fsdp_state_dict_type` to `SHARDED_STATE_DICT` or `FULL_STATE_DICT`." + ) + + def set_auto_wrap_policy(self, model): + """ + Given `model`, creates an `auto_wrap_policy` based on the passed in policy and if we can use the + `transformer_cls_to_wrap` + """ + from torch.distributed.fsdp.wrap import ( + size_based_auto_wrap_policy, + transformer_auto_wrap_policy, + ) + + # First base off of `_no_split_modules` + no_split_modules = getattr(model, "_no_split_modules", None) + default_transformer_cls_names_to_wrap = list(no_split_modules) if no_split_modules is not None else [] + if self.auto_wrap_policy == transformer_auto_wrap_policy: + if self.transformer_cls_names_to_wrap is None: + self.transformer_cls_names_to_wrap = default_transformer_cls_names_to_wrap + transformer_cls_to_wrap = set() + for layer_class in self.transformer_cls_names_to_wrap: + transformer_cls = get_module_class_from_name(model, layer_class) + if transformer_cls is None: + raise ValueError(f"Could not find the transformer layer class {layer_class} in the model.") + transformer_cls_to_wrap.add(transformer_cls) + # Finally we set the auto_wrap_policy to a callable + self.auto_wrap_policy = functools.partial( + self.auto_wrap_policy, transformer_layer_cls=transformer_cls_to_wrap + ) + + elif self.auto_wrap_policy == size_based_auto_wrap_policy: + # If zero, we silently ignore it. + if self.min_num_params > 0: + self.auto_wrap_policy = functools.partial(self.auto_wrap_policy, min_num_params=self.min_num_params) + else: + self.auto_wrap_policy = None + + def set_mixed_precision(self, mixed_precision, buffer_autocast=False, override=False): + "Sets the mixed precision policy for FSDP" + mixed_precision_mapping = { + "fp8": torch.bfloat16, + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, + } + dtype = mixed_precision + if isinstance(mixed_precision, str): + dtype = mixed_precision_mapping.get(mixed_precision, None) + if dtype is None: + raise ValueError( + f"Invalid mixed precision: {mixed_precision}. Must be one of {list(mixed_precision_mapping.keys())}" + ) + elif isinstance(mixed_precision, torch.dtype) and mixed_precision not in mixed_precision_mapping.values(): + raise ValueError( + f"Invalid mixed precision: {mixed_precision}. Must be one of {list(mixed_precision_mapping.values())}" + ) + + buffer_type = torch.float32 if buffer_autocast else dtype + + if self.fsdp_version == 1: + from torch.distributed.fsdp import MixedPrecision + elif self.fsdp_version == 2: + from torch.distributed.fsdp import MixedPrecisionPolicy as MixedPrecision + + if override or self.mixed_precision_policy is None: + dtype_args = {"param_dtype": dtype, "reduce_dtype": dtype} + if self.fsdp_version == 1: + dtype_args["buffer_dtype"] = buffer_type + else: + dtype_args["output_dtype"] = dtype + # TODO(s1ro1): `cast_forward_inputs` for FSDP2? + self.mixed_precision_policy = MixedPrecision(**dtype_args) + elif isinstance(self.mixed_precision_policy, dict): + # Check for incompatible types + valid_keys = ["param_dtype", "reduce_dtype"] + ( + ["buffer_dtype"] if self.fsdp_version == 1 else ["output_dtype"] + ) + missing_keys = [k for k in valid_keys if k not in self.mixed_precision_policy] + invalid_values = [ + k for k, v in self.mixed_precision_policy.items() if v not in mixed_precision_mapping.values() + ] + if missing_keys or invalid_values: + raise ValueError( + f"Invalid mixed precision policy: {self.mixed_precision_policy}. " + f"Must be a `dict` with keys {valid_keys}." + f"Values must be one of {list(mixed_precision_mapping.values())}" + ) + self.mixed_precision_policy = MixedPrecision(**self.mixed_precision_policy) + + def validate_mixed_precision_policy(self): + """ + Validates the mixed precision policy, abstracted away to not bring in the imports if not needed. + """ + if self.fsdp_version == 2: + from torch.distributed.fsdp import MixedPrecisionPolicy as MixedPrecision + else: + from torch.distributed.fsdp import MixedPrecision + + if not isinstance(self.mixed_precision_policy, MixedPrecision): + required_type = ( + "`torch.distributed.fsdp.MixedPrecisionPolicy`" + if self.fsdp_version == 2 + else "`torch.distributed.fsdp.MixedPrecision`" + ) + raise ValueError(f"mixed_precision_policy must be an instance of {required_type}.") + + def set_cpu_offload(self): + if self.fsdp_version == 2: + from torch.distributed.fsdp import CPUOffloadPolicy, OffloadPolicy + else: + from torch.distributed.fsdp import CPUOffload + + if isinstance(self.cpu_offload, bool): + if self.fsdp_version == 2: + if not self.cpu_offload: + self.cpu_offload = OffloadPolicy() + else: + self.cpu_offload = CPUOffloadPolicy() + else: + self.cpu_offload = CPUOffload(offload_params=self.cpu_offload) + + def validate_cpu_offload(self): + if self.fsdp_version == 2: + from torch.distributed.fsdp import OffloadPolicy + else: + from torch.distributed.fsdp import CPUOffload + + if self.fsdp_version == 2 and not isinstance(self.cpu_offload, OffloadPolicy): + raise ValueError( + f"`cpu_offload` must be an instance of `torch.distributed.fsdp.OffloadPolicy` in FSDP2, got {self.cpu_offload}" + ) + if self.fsdp_version == 1 and not isinstance(self.cpu_offload, CPUOffload): + raise ValueError( + f"`cpu_offload` must be an instance of `torch.distributed.fsdp.CPUOffload` in FSDP1, got {self.cpu_offload}" + ) + + +@dataclass +class TorchTensorParallelPlugin: + """ + This plugin is used to enable tensor parallelism using PyTorch >= 2.0. + """ + + tp_size: int = field( + default=1, + metadata={"help": "tensor parallel size will be used in the device mesh preparation"}, + ) + + # torch_device_mesh is of type "torch.distributed.DeviceMesh" + torch_device_mesh: Optional["torch.distributed.DeviceMesh"] = field(default=None) + + +@dataclass +class TorchContextParallelConfig: + """ + This class holds the configuration for context parallelism in PyTorch. + """ + + cp_comm_strategy: Optional[str] = field( + default=None, + metadata={ + "help": "Communication strategy for context parallelism. Can be one of 'allgather' or 'alltoall'. Defaults to 'allgather'." + }, + ) + + def __post_init__(self): + if not is_torch_version(">=", BETA_CP_AVAILABLE_PYTORCH_VERSION): + raise ValueError( + f"FSDP2-based Context parallelism is only available in PyTorch {BETA_CP_AVAILABLE_PYTORCH_VERSION} and later versions. " + "Please upgrade your PyTorch version." + ) + + if self.cp_comm_strategy is None: + self.cp_comm_strategy = os.environ.get("PARALLELISM_CONFIG_CP_COMM_STRATEGY", "allgather") + if self.cp_comm_strategy not in ["allgather", "alltoall"]: + raise ValueError( + f"Invalid cp_comm_strategy: {self.cp_comm_strategy}. Must be one of 'allgather' or 'alltoall'." + ) + + +@dataclass +class DeepSpeedSequenceParallelConfig: + sp_seq_length: Optional[int] = field( + default=None, + metadata={ + "help": "Sequence length for when batches are all of the same length. For variable sequence lengths across batches set `sp_seq_length_is_variable=True` and leave this field unset" + }, + ) + sp_seq_length_is_variable: Optional[bool] = field( + default=None, + metadata={ + "help": "If `True` will work with a sequence length that may change between batches, in which case `sp_seq_length` value can be set to anything divisible by cp size or remain unset. If `False` then `sp_seq_length` needs to match the batch's sequence length dimension. The default is `True`." + }, + ) + sp_attn_implementation: Optional[str] = field( + default=None, + metadata={ + "help": "Attention implementation to use. Can be one of 'flash_attention_2', 'flash_attention_3' or 'sdpa'. Defaults to `sdpa`." + }, + ) + + def __post_init__(self): + # sp_seq_length_is_variable and sp_seq_length are interconnected + if self.sp_seq_length_is_variable is None: + self.sp_seq_length_is_variable = ( + os.environ.get("PARALLELISM_CONFIG_SP_SEQ_LENGTH_IS_VARIABLE", "true").lower() == "true" + ) + + if not self.sp_seq_length_is_variable and self.sp_seq_length is None: + if "PARALLELISM_CONFIG_SP_SEQ_LENGTH" not in os.environ: + raise ValueError( + "when `sp_seq_length_is_variable` is `False` `sp_seq_length` must be provided either through the constructor or the environment variable PARALLELISM_CONFIG_SP_SEQ_LENGTH" + ) + else: + self.sp_seq_length = os.environ.get("PARALLELISM_CONFIG_SP_SEQ_LENGTH") + self.sp_seq_length = None if self.sp_seq_length == "None" else int(self.sp_seq_length) + + if self.sp_attn_implementation is None: + self.sp_attn_implementation = os.environ.get("PARALLELISM_CONFIG_SP_ATTN_IMPLEMENTATION", None) + + if self.sp_attn_implementation is not None and self.sp_attn_implementation not in [ + "flash_attention_2", + "flash_attention_3", + "sdpa", + ]: + raise ValueError( + f"Invalid sp_attn_implementation: {self.sp_attn_implementation}. Must be one of 'flash_attention_2', 'flash_attention_3' or 'sdpa'." + ) + + +@dataclass +class TorchTensorParallelConfig: + """ + Use this object in your [`Accelerator`] to customize your torch tensor parallelism. + """ + + enable_async_tp: bool = False + + def __post_init__(self): + if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION): + raise ValueError( + f"Torch tensor parallelism is only available in PyTorch {BETA_TP_AVAILABLE_PYTORCH_VERSION} and later versions. " + "Please upgrade your PyTorch version." + ) + + if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION): + raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}") + + if self.enable_async_tp: + warnings.warn("Async tensor parallelism is currently not supported, ignoring this option.") + + +@dataclass +class MegatronLMPlugin: + """ + Plugin for Megatron-LM to enable tensor, pipeline, sequence and data parallelism. Also to enable selective + activation recomputation and optimized fused kernels. + + Args: + tp_degree (`int`, defaults to `None`): + Tensor parallelism degree. + pp_degree (`int`, defaults to `None`): + Pipeline parallelism degree. + num_micro_batches (`int`, defaults to `None`): + Number of micro-batches. + gradient_clipping (`float`, defaults to `None`): + Gradient clipping value based on global L2 Norm (0 to disable). + sequence_parallelism (`bool`, defaults to `None`): + Enable sequence parallelism. + recompute_activations (`bool`, defaults to `None`): + Enable selective activation recomputation. + use_distributed_optimizr (`bool`, defaults to `None`): + Enable distributed optimizer. + pipeline_model_parallel_split_rank (`int`, defaults to `None`): + Rank where encoder and decoder should be split. + num_layers_per_virtual_pipeline_stage (`int`, defaults to `None`): + Number of layers per virtual pipeline stage. + is_train_batch_min (`str`, defaults to `True`): + If both tran & eval dataloaders are specified, this will decide the `micro_batch_size`. + train_iters (`int`, defaults to `None`): + Total number of samples to train over all training runs. Note that either train-iters or train-samples + should be provided when using `MegatronLMDummyScheduler`. + train_samples (`int`, defaults to `None`): + Total number of samples to train over all training runs. Note that either train-iters or train-samples + should be provided when using `MegatronLMDummyScheduler`. + weight_decay_incr_style (`str`, defaults to `'constant'`): + Weight decay increment function. choices=["constant", "linear", "cosine"]. + start_weight_decay (`float`, defaults to `None`): + Initial weight decay coefficient for L2 regularization. + end_weight_decay (`float`, defaults to `None`): + End of run weight decay coefficient for L2 regularization. + lr_decay_style (`str`, defaults to `'linear'`): + Learning rate decay function. choices=['constant', 'linear', 'cosine']. + lr_decay_iters (`int`, defaults to `None`): + Number of iterations for learning rate decay. If None defaults to `train_iters`. + lr_decay_samples (`int`, defaults to `None`): + Number of samples for learning rate decay. If None defaults to `train_samples`. + lr_warmup_iters (`int`, defaults to `None`): + Number of iterations to linearly warmup learning rate over. + lr_warmup_samples (`int`, defaults to `None`): + Number of samples to linearly warmup learning rate over. + lr_warmup_fraction (`float`, defaults to `None`): + Fraction of lr-warmup-(iters/samples) to linearly warmup learning rate over. + min_lr (`float`, defaults to `0`): + Minimum value for learning rate. The scheduler clip values below this threshold. + consumed_samples (`List`, defaults to `None`): + Number of samples consumed in the same order as the dataloaders to `accelerator.prepare` call. + no_wd_decay_cond (`Optional`, defaults to `None`): + Condition to disable weight decay. + scale_lr_cond (`Optional`, defaults to `None`): + Condition to scale learning rate. + lr_mult (`float`, defaults to `1.0`): + Learning rate multiplier. + megatron_dataset_flag (`bool`, defaults to `False`): + Whether the format of dataset follows Megatron-LM Indexed/Cached/MemoryMapped format. + seq_length (`int`, defaults to `None`): + Maximum sequence length to process. + encoder_seq_length (`int`, defaults to `None`): + Maximum sequence length to process for the encoder. + decoder_seq_length (`int`, defaults to `None`): + Maximum sequence length to process for the decoder. + tensorboard_dir (`str`, defaults to `None`): + Path to save tensorboard logs. + set_all_logging_options (`bool`, defaults to `False`): + Whether to set all logging options. + eval_iters (`int`, defaults to `100`): + Number of iterations to run for evaluation validation/test for. + eval_interval (`int`, defaults to `1000`): + Interval between running evaluation on validation set. + return_logits (`bool`, defaults to `False`): + Whether to return logits from the model. + custom_train_step_class (`Optional`, defaults to `None`): + Custom train step class. + custom_train_step_kwargs (`Optional`, defaults to `None`): + Custom train step kwargs. + custom_model_provider_function (`Optional`, defaults to `None`): + Custom model provider function. + custom_prepare_model_function (`Optional`, defaults to `None`): + Custom prepare model function. + custom_megatron_datasets_provider_function (`Optional`, defaults to `None`): + Custom megatron train_valid_test datasets provider function. + custom_get_batch_function (`Optional`, defaults to `None`): + Custom get batch function. + custom_loss_function (`Optional`, defaults to `None`): + Custom loss function. + other_megatron_args (`Optional`, defaults to `None`): + Other Megatron-LM arguments. Please refer Megatron-LM. + """ + + tp_degree: int = field(default=None, metadata={"help": "tensor parallelism degree."}) + pp_degree: int = field(default=None, metadata={"help": "pipeline parallelism degree."}) + use_custom_fsdp: bool = field(default=None, metadata={"help": "use custom fsdp."}) + overlap_cpu_optimizer_d2h_h2d: bool = field( + default=None, metadata={"help": "overlap CPU optimizer step, gradients D2H and updated parameters H2D."} + ) + no_load_optim: bool = field(default=None, metadata={"help": "do not load optimizer."}) + eod_mask_loss: bool = field(default=None, metadata={"help": "use eod mask loss."}) + no_save_optim: bool = field(default=None, metadata={"help": "do not save optimizer."}) + optimizer_cpu_offload: bool = field(default=None, metadata={"help": "use CPU offload for optimizer."}) + use_precision_aware_optimizer: bool = field(default=None, metadata={"help": "use precision aware optimizer."}) + decoder_last_pipeline_num_layers: int = field( + default=None, + metadata={ + "help": "decoder last pipeline number of layers, default None is even split of transformer layers across all pipeline stages." + }, + ) + recompute_granularity: str = field(default=None, metadata={"help": "recompute granularity (full, selective)."}) + recompute_method: str = field(default=None, metadata={"help": "recompute method (uniform, block)."}) + recompute_num_layers: int = field(default=None, metadata={"help": "number of layers to recompute."}) + attention_backend: bool = field(default=None, metadata={"help": "enable attention backend."}) + expert_model_parallel_size: int = field(default=None, metadata={"help": "expert model parallel size."}) + context_parallel_size: int = field(default=None, metadata={"help": "context parallel size."}) + attention_dropout: float = field(default=None, metadata={"help": "attention dropout rate."}) + hidden_dropout: float = field(default=None, metadata={"help": "hidden dropout rate."}) + attention_softmax_in_fp32: bool = field(default=None, metadata={"help": "use fp32 for attention softmax."}) + expert_tensor_parallel_size: int = field(default=None, metadata={"help": "expert tensor parallel size."}) + calculate_per_token_loss: bool = field(default=None, metadata={"help": "calculate per token loss."}) + use_rotary_position_embeddings: bool = field(default=None, metadata={"help": "use rotary position embeddings."}) + num_micro_batches: int = field(default=None, metadata={"help": "number of micro-batches."}) + gradient_clipping: float = field( + default=None, + metadata={"help": "gradient clipping value based on global L2 Norm (0 to disable)"}, + ) + sequence_parallelism: bool = field( + default=None, + metadata={"help": "enable sequence parallelism"}, + ) + recompute_activations: bool = field( + default=None, + metadata={"help": "enable selective activation recomputation"}, + ) + use_distributed_optimizer: bool = field( + default=None, + metadata={"help": "enable distributed optimizer"}, + ) + pipeline_model_parallel_split_rank: int = field( + default=None, + metadata={"help": "Rank where encoder and decoder should be split."}, + ) + num_layers_per_virtual_pipeline_stage: int = field( + default=None, metadata={"help": "Number of layers per virtual pipeline stage."} + ) + is_train_batch_min: str = field( + default=True, + metadata={"help": "If both train & eval dataloaders are specified, this will decide the micro_batch_size"}, + ) + train_iters: int = field( + default=None, + metadata={ + "help": "Total number of iterations to train over all training runs. " + "Note that either train-iters or train-samples should be provided when using `MegatronLMDummyScheduler`" + }, + ) + train_samples: int = field( + default=None, + metadata={ + "help": "Total number of samples to train over all training runs. " + "Note that either train-iters or train-samples should be provided when using `MegatronLMDummyScheduler`" + }, + ) + weight_decay_incr_style: str = field( + default="constant", + metadata={"help": 'Weight decay increment function. choices=["constant", "linear", "cosine"]. '}, + ) + start_weight_decay: float = field( + default=None, + metadata={"help": "Initial weight decay coefficient for L2 regularization."}, + ) + end_weight_decay: float = field( + default=None, + metadata={"help": "End of run weight decay coefficient for L2 regularization."}, + ) + lr_decay_style: str = field( + default="linear", + metadata={"help": "Learning rate decay function. choices=['constant', 'linear', 'cosine']."}, + ) + lr_decay_iters: int = field( + default=None, + metadata={"help": "Number of iterations for learning rate decay. If None defaults to `train_iters`."}, + ) + lr_decay_samples: int = field( + default=None, + metadata={"help": "Number of samples for learning rate decay. If None defaults to `train_samples`."}, + ) + lr_warmup_iters: int = field( + default=None, + metadata={"help": "number of iterations to linearly warmup learning rate over."}, + ) + lr_warmup_samples: int = field( + default=None, + metadata={"help": "number of samples to linearly warmup learning rate over."}, + ) + lr_warmup_fraction: float = field( + default=None, + metadata={"help": "fraction of lr-warmup-(iters/samples) to linearly warmup learning rate over."}, + ) + min_lr: float = field( + default=0, + metadata={"help": "Minimum value for learning rate. The scheduler clip values below this threshold."}, + ) + consumed_samples: list[int] = field( + default=None, + metadata={ + "help": "Number of samples consumed in the same order as the dataloaders to `accelerator.prepare` call." + }, + ) + no_wd_decay_cond: Optional[Callable] = field(default=None, metadata={"help": "Condition to disable weight decay."}) + scale_lr_cond: Optional[Callable] = field(default=None, metadata={"help": "Condition to scale learning rate."}) + lr_mult: float = field(default=1.0, metadata={"help": "Learning rate multiplier."}) + megatron_dataset_flag: bool = field( + default=False, + metadata={"help": "Whether the format of dataset follows Megatron-LM Indexed/Cached/MemoryMapped format."}, + ) + seq_length: int = field( + default=None, + metadata={"help": "Maximum sequence length to process."}, + ) + encoder_seq_length: int = field( + default=None, + metadata={"help": "Maximum sequence length to process for the encoder."}, + ) + decoder_seq_length: int = field( + default=None, + metadata={"help": "Maximum sequence length to process for the decoder."}, + ) + tensorboard_dir: str = field( + default=None, + metadata={"help": "Path to save tensorboard logs."}, + ) + set_all_logging_options: bool = field( + default=False, + metadata={"help": "Whether to set all logging options."}, + ) + eval_iters: int = field( + default=100, + metadata={"help": "Number of iterations to run for evaluation validation/test for."}, + ) + eval_interval: int = field( + default=1000, + metadata={"help": "Interval between running evaluation on validation set."}, + ) + return_logits: bool = field( + default=False, + metadata={"help": "Whether to return logits from the model."}, + ) + + # custom train step args + custom_train_step_class: Optional[Any] = field( + default=None, + metadata={"help": "Custom train step class."}, + ) + custom_train_step_kwargs: Optional[dict[str, Any]] = field( + default=None, + metadata={"help": "Custom train step kwargs."}, + ) + + # custom model args + custom_model_provider_function: Optional[Callable] = field( + default=None, + metadata={"help": "Custom model provider function."}, + ) + custom_prepare_model_function: Optional[Callable] = field( + default=None, + metadata={"help": "Custom prepare model function."}, + ) + custom_megatron_datasets_provider_function: Optional[Callable] = field( + default=None, + metadata={"help": "Custom megatron train_valid_test datasets provider function."}, + ) + custom_get_batch_function: Optional[Callable] = field( + default=None, + metadata={"help": "Custom get batch function."}, + ) + custom_loss_function: Optional[Callable] = field( + default=None, + metadata={"help": "Custom loss function."}, + ) + + # remaining args such as enabling Alibi/ROPE positional embeddings, + # wandb logging, Multi-Query Attention, etc. + other_megatron_args: Optional[dict[str, Any]] = field( + default=None, + metadata={"help": "Other Megatron-LM arguments. Please refer Megatron-LM"}, + ) + + def __post_init__(self): + prefix = "MEGATRON_LM_" + if self.tp_degree is None: + self.tp_degree = int(os.environ.get(prefix + "TP_DEGREE", 1)) + if self.pp_degree is None: + self.pp_degree = int(os.environ.get(prefix + "PP_DEGREE", 1)) + if self.use_custom_fsdp is None: + self.use_custom_fsdp = str_to_bool(os.environ.get(prefix + "USE_CUSTOM_FSDP", "False")) == 1 + if self.no_load_optim is None: + self.no_load_optim = str_to_bool(os.environ.get(prefix + "NO_LOAD_OPTIM", "False")) == 1 + if self.eod_mask_loss is None: + self.eod_mask_loss = str_to_bool(os.environ.get(prefix + "EOD_MASK_LOSS", "False")) == 1 + if self.no_save_optim is None: + self.no_save_optim = str_to_bool(os.environ.get(prefix + "NO_SAVE_OPTIM", "False")) == 1 + if self.optimizer_cpu_offload is None: + self.optimizer_cpu_offload = str_to_bool(os.environ.get(prefix + "OPTIMIZER_CPU_OFFLOAD", "False")) == 1 + if self.overlap_cpu_optimizer_d2h_h2d is None: + self.overlap_cpu_optimizer_d2h_h2d = ( + str_to_bool(os.environ.get(prefix + "OVERLAP_CPU_OPTIMIZER_D2H_H2D", "False")) == 1 + ) + if self.use_precision_aware_optimizer is None: + self.use_precision_aware_optimizer = ( + str_to_bool(os.environ.get(prefix + "USE_PRECISION_AWARE_OPTIMIZER", "False")) == 1 + ) + if self.decoder_last_pipeline_num_layers is None: + if os.environ.get(prefix + "DECODER_LAST_PIPELINE_NUM_LAYERS") is not None: + self.decoder_last_pipeline_num_layers = int( + os.environ.get(prefix + "DECODER_LAST_PIPELINE_NUM_LAYERS", 0) + ) + else: + self.decoder_last_pipeline_num_layers = None + if self.num_micro_batches is None: + self.num_micro_batches = int(os.environ.get(prefix + "NUM_MICRO_BATCHES", 1)) + if self.gradient_clipping is None: + self.gradient_clipping = float(os.environ.get(prefix + "GRADIENT_CLIPPING", 1.0)) + if self.recompute_activations is None: + self.recompute_activations = str_to_bool(os.environ.get(prefix + "RECOMPUTE_ACTIVATIONS", "False")) == 1 + if self.use_distributed_optimizer is None: + self.use_distributed_optimizer = ( + str_to_bool(os.environ.get(prefix + "USE_DISTRIBUTED_OPTIMIZER", "False")) == 1 + ) + if self.sequence_parallelism is None: + self.sequence_parallelism = str_to_bool(os.environ.get(prefix + "SEQUENCE_PARALLELISM", "False")) == 1 + if self.recompute_granularity is None: + self.recompute_granularity = os.environ.get(prefix + "RECOMPUTE_GRANULARITY", "full") + if self.recompute_method is None: + self.recompute_method = os.environ.get(prefix + "RECOMPUTE_METHOD", "uniform") + if self.recompute_num_layers is None: + self.recompute_num_layers = int(os.environ.get(prefix + "RECOMPUTE_NUM_LAYERS", 1)) + if self.attention_backend is None: + self.attention_backend = str_to_bool(os.environ.get(prefix + "ATTENTION_BACKEND", "True")) == 1 + if self.expert_model_parallel_size is None: + self.expert_model_parallel_size = int(os.environ.get(prefix + "EXPERT_MODEL_PARALLEL_SIZE", 1)) + if self.context_parallel_size is None: + self.context_parallel_size = int(os.environ.get(prefix + "CONTEXT_PARALLEL_SIZE", 2)) + if self.attention_dropout is None: + self.attention_dropout = float(os.environ.get(prefix + "ATTENTION_DROPOUT", "0.0")) + if self.hidden_dropout is None: + self.hidden_dropout = float(os.environ.get(prefix + "HIDDEN_DROPOUT", "0.0")) + if self.attention_softmax_in_fp32 is None: + self.attention_softmax_in_fp32 = ( + str_to_bool(os.environ.get(prefix + "ATTENTION_SOFTMAX_IN_FP32", "True")) == 1 + ) + if self.expert_tensor_parallel_size is None: + self.expert_tensor_parallel_size = int(os.environ.get(prefix + "EXPERT_TENSOR_PARALLEL_SIZE", 1)) + if self.calculate_per_token_loss is None: + self.calculate_per_token_loss = ( + str_to_bool(os.environ.get(prefix + "CALCULATE_PER_TOKEN_LOSS", "True")) == 1 + ) + if self.use_rotary_position_embeddings is None: + self.use_rotary_position_embeddings = ( + str_to_bool(os.environ.get(prefix + "USE_ROTARY_POSITION_EMBEDDINGS", "True")) == 1 + ) + + if self.pp_degree > 1 or self.use_distributed_optimizer: + self.DDP_impl = "local" + else: + self.DDP_impl = "torch" + + if self.consumed_samples is not None: + if len(self.consumed_samples) == 1: + self.consumed_samples.extend([0, 0]) + elif len(self.consumed_samples) == 2: + self.consumed_samples.append(0) + + self.megatron_lm_default_args = { + "tensor_model_parallel_size": self.tp_degree, + "pipeline_model_parallel_size": self.pp_degree, + "pipeline_model_parallel_split_rank": self.pipeline_model_parallel_split_rank, + "num_layers_per_virtual_pipeline_stage": self.num_layers_per_virtual_pipeline_stage, + "DDP_impl": self.DDP_impl, + "use_distributed_optimizer": self.use_distributed_optimizer, + "sequence_parallel": self.sequence_parallelism, + "clip_grad": self.gradient_clipping, + "num_micro_batches": self.num_micro_batches, + "consumed_samples": self.consumed_samples, + "no_wd_decay_cond": self.no_wd_decay_cond, + "scale_lr_cond": self.scale_lr_cond, + "lr_mult": self.lr_mult, + "megatron_dataset_flag": self.megatron_dataset_flag, + "eval_iters": self.eval_iters, + "eval_interval": self.eval_interval, + "use_custom_fsdp": self.use_custom_fsdp, + "no_load_optim": self.no_load_optim, + "eod_mask_loss": self.eod_mask_loss, + "no_save_optim": self.no_save_optim, + "optimizer_cpu_offload": self.optimizer_cpu_offload, + "overlap_cpu_optimizer_d2h_h2d": self.overlap_cpu_optimizer_d2h_h2d, + "use_precision_aware_optimizer": self.use_precision_aware_optimizer, + "decoder_last_pipeline_num_layers": self.decoder_last_pipeline_num_layers, + "recompute_granularity": self.recompute_granularity, + "recompute_method": self.recompute_method, + "recompute_num_layers": self.recompute_num_layers, + "attention_backend": self.attention_backend, + "expert_model_parallel_size": self.expert_model_parallel_size, + "context_parallel_size": self.context_parallel_size, + "attention_dropout": self.attention_dropout, + "hidden_dropout": self.hidden_dropout, + "attention_softmax_in_fp32": self.attention_softmax_in_fp32, + "expert_tensor_parallel_size": self.expert_tensor_parallel_size, + "calculate_per_token_loss": self.calculate_per_token_loss, + "use_rotary_position_embeddings": self.use_rotary_position_embeddings, + } + if self.tensorboard_dir is not None: + self.megatron_lm_default_args["tensorboard_dir"] = self.tensorboard_dir + if self.set_all_logging_options: + self.set_tensorboard_logging_options() + if self.other_megatron_args is not None: + self.megatron_lm_default_args.update(self.other_megatron_args) + + def set_network_size_args(self, model, batch_data=None): + model_config_type = model.config.model_type.lower() + for model_type in MODEL_CONFIGS_TO_MEGATRON_PARSERS.keys(): + if model_type in model_config_type: + MODEL_CONFIGS_TO_MEGATRON_PARSERS[model_type](self, model, batch_data) + return + raise ValueError( + f"Accelerate Megatron-LM integration not supports {model_config_type} model. " + "You can add your own model config parser." + ) + + def set_mixed_precision(self, mixed_precision): + if mixed_precision == "fp16": + self.megatron_lm_default_args["fp16"] = True + elif mixed_precision == "bf16": + self.megatron_lm_default_args["bf16"] = True + self.DDP_impl = "local" + self.megatron_lm_default_args["DDP_impl"] = self.DDP_impl + + def set_training_args(self, micro_batch_size, dp_degree): + self.data_parallel_size = dp_degree + self.micro_batch_size = micro_batch_size + self.global_batch_size = dp_degree * micro_batch_size * self.num_micro_batches + self.megatron_lm_default_args["data_parallel_size"] = self.data_parallel_size + self.megatron_lm_default_args["micro_batch_size"] = self.micro_batch_size + self.megatron_lm_default_args["global_batch_size"] = self.global_batch_size + + def set_optimizer_type(self, optimizer): + optimizer_name = optimizer.__class__.__name__.lower() + if "adam" in optimizer_name: + self.megatron_lm_default_args["optimizer"] = "adam" + self.megatron_lm_default_args["adam_beta1"] = optimizer.defaults["betas"][0] + self.megatron_lm_default_args["adam_beta2"] = optimizer.defaults["betas"][1] + self.megatron_lm_default_args["adam_eps"] = optimizer.defaults["eps"] + elif "sgd" in optimizer_name: + self.megatron_lm_default_args["optimizer"] = "sgd" + self.megatron_lm_default_args["sgd_momentum"] = optimizer.defaults["momentum"] + else: + raise ValueError(f"Optimizer {optimizer_name} is not supported by Megatron-LM") + + self.megatron_lm_default_args["lr"] = optimizer.defaults["lr"] + self.megatron_lm_default_args["weight_decay"] = optimizer.defaults["weight_decay"] + + def set_scheduler_args(self, scheduler): + if self.train_iters is None: + self.train_iters = scheduler.total_num_steps // self.megatron_lm_default_args["data_parallel_size"] + if self.train_samples is not None: + self.train_samples = None + warnings.warn( + "Ignoring `train_samples` as `train_iters` based on scheduler is being used for training." + ) + if self.lr_warmup_iters is None: + self.lr_warmup_iters = scheduler.warmup_num_steps // self.megatron_lm_default_args["data_parallel_size"] + if self.lr_warmup_samples is not None: + warnings.warn( + "Ignoring `lr_warmup_samples` as `lr_warmup_iters` based on scheduler is being used for training." + ) + self.lr_warmup_samples = 0 + + self.megatron_lm_default_args["train_iters"] = self.train_iters + self.megatron_lm_default_args["lr_warmup_iters"] = self.lr_warmup_iters + self.megatron_lm_default_args["train_samples"] = self.train_samples + self.megatron_lm_default_args["lr_warmup_samples"] = self.lr_warmup_samples + self.megatron_lm_default_args["lr_decay_iters"] = self.lr_decay_iters + self.megatron_lm_default_args["lr_decay_samples"] = self.lr_decay_samples + self.megatron_lm_default_args["lr_warmup_fraction"] = self.lr_warmup_fraction + self.megatron_lm_default_args["lr_decay_style"] = self.lr_decay_style + self.megatron_lm_default_args["weight_decay_incr_style"] = self.weight_decay_incr_style + self.megatron_lm_default_args["start_weight_decay"] = self.start_weight_decay + self.megatron_lm_default_args["end_weight_decay"] = self.end_weight_decay + self.megatron_lm_default_args["min_lr"] = self.min_lr + + def set_tensorboard_logging_options(self): + from megatron.training.arguments import _add_logging_args + + parser = argparse.ArgumentParser() + parser = _add_logging_args(parser) + logging_args = parser.parse_known_args() + self.dataset_args = vars(logging_args[0]) + for key, value in self.dataset_args.items(): + if key.startswith("log_"): + self.megatron_lm_default_args[key] = True + elif key.startswith("no_log_"): + self.megatron_lm_default_args[key.replace("no_", "")] = True + + +MODEL_CONFIGS_TO_MEGATRON_PARSERS = {} + + +def add_model_config_to_megatron_parser(model_type: str): + def add_model_config_parser_helper(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + MODEL_CONFIGS_TO_MEGATRON_PARSERS[model_type] = func + return wrapper + + return add_model_config_parser_helper + + +@add_model_config_to_megatron_parser("megatron-bert") +def parse_bert_config(megatron_lm_plugin, model, batch_data): + model_type_name = "bert" + num_layers = model.config.num_hidden_layers + hidden_size = model.config.hidden_size + num_attention_heads = model.config.num_attention_heads + max_position_embeddings = model.config.max_position_embeddings + num_labels = model.config.num_labels + orig_vocab_size = model.config.vocab_size + pretraining_flag = False + if "maskedlm" in model.__class__.__name__.lower(): + pretraining_flag = True + if megatron_lm_plugin.seq_length is not None: + if megatron_lm_plugin.encoder_seq_length is not None: + warnings.warn("Both `seq_length` and `encoder_seq_length` are set. Using `encoder_seq_length`.") + megatron_lm_plugin.seq_length = megatron_lm_plugin.encoder_seq_length + elif megatron_lm_plugin.encoder_seq_length is not None: + megatron_lm_plugin.seq_length = megatron_lm_plugin.encoder_seq_length + elif batch_data is not None: + megatron_lm_plugin.seq_length = batch_data["input_ids"].shape[1] + else: + megatron_lm_plugin.seq_length = max_position_embeddings + megatron_lm_plugin.megatron_lm_default_args["seq_length"] = megatron_lm_plugin.seq_length + megatron_lm_plugin.megatron_lm_default_args["model_type_name"] = model_type_name + megatron_lm_plugin.megatron_lm_default_args["num_layers"] = num_layers + megatron_lm_plugin.megatron_lm_default_args["hidden_size"] = hidden_size + megatron_lm_plugin.megatron_lm_default_args["num_attention_heads"] = num_attention_heads + megatron_lm_plugin.megatron_lm_default_args["max_position_embeddings"] = max_position_embeddings + megatron_lm_plugin.megatron_lm_default_args["pretraining_flag"] = pretraining_flag + megatron_lm_plugin.megatron_lm_default_args["orig_vocab_size"] = orig_vocab_size + megatron_lm_plugin.megatron_lm_default_args["model_return_dict"] = model.config.return_dict + megatron_lm_plugin.megatron_lm_default_args["num_labels"] = num_labels + + +@add_model_config_to_megatron_parser("gpt2") +def parse_gpt2_config(megatron_lm_plugin, model, batch_data): + model_type_name = "gpt" + num_layers = model.config.n_layer + hidden_size = model.config.n_embd + num_attention_heads = model.config.n_head + max_position_embeddings = model.config.n_positions + orig_vocab_size = model.config.vocab_size + pretraining_flag = True + if megatron_lm_plugin.seq_length is not None: + if megatron_lm_plugin.decoder_seq_length is not None: + warnings.warn("Both `seq_length` and `decoder_seq_length` are set. Using `decoder_seq_length`.") + megatron_lm_plugin.seq_length = megatron_lm_plugin.decoder_seq_length + elif megatron_lm_plugin.decoder_seq_length is not None: + megatron_lm_plugin.seq_length = megatron_lm_plugin.decoder_seq_length + elif batch_data is not None: + megatron_lm_plugin.seq_length = batch_data["input_ids"].shape[1] + else: + megatron_lm_plugin.seq_length = max_position_embeddings + megatron_lm_plugin.megatron_lm_default_args["seq_length"] = megatron_lm_plugin.seq_length + megatron_lm_plugin.megatron_lm_default_args["return_logits"] = megatron_lm_plugin.return_logits + megatron_lm_plugin.megatron_lm_default_args["tokenizer_type"] = "GPT2BPETokenizer" + megatron_lm_plugin.megatron_lm_default_args["model_type_name"] = model_type_name + megatron_lm_plugin.megatron_lm_default_args["num_layers"] = num_layers + megatron_lm_plugin.megatron_lm_default_args["hidden_size"] = hidden_size + megatron_lm_plugin.megatron_lm_default_args["num_attention_heads"] = num_attention_heads + megatron_lm_plugin.megatron_lm_default_args["max_position_embeddings"] = max_position_embeddings + megatron_lm_plugin.megatron_lm_default_args["pretraining_flag"] = pretraining_flag + megatron_lm_plugin.megatron_lm_default_args["orig_vocab_size"] = orig_vocab_size + megatron_lm_plugin.megatron_lm_default_args["model_return_dict"] = model.config.return_dict + + +@add_model_config_to_megatron_parser("t5") +def parse_t5_config(megatron_lm_plugin, model, batch_data): + model_type_name = "t5" + num_layers = model.config.num_layers + hidden_size = model.config.d_model + num_attention_heads = model.config.num_heads + max_position_embeddings = model.config.n_positions if hasattr(model.config, "n_positions") else 1024 + orig_vocab_size = model.config.vocab_size + pretraining_flag = True + if megatron_lm_plugin.encoder_seq_length is None: + if batch_data is not None: + megatron_lm_plugin.encoder_seq_length = batch_data["input_ids"].shape[1] + else: + megatron_lm_plugin.encoder_seq_length = max_position_embeddings + if megatron_lm_plugin.decoder_seq_length is None: + if batch_data is not None: + megatron_lm_plugin.decoder_seq_length = batch_data["labels"].shape[1] + else: + megatron_lm_plugin.decoder_seq_length = max_position_embeddings + megatron_lm_plugin.megatron_lm_default_args["encoder_seq_length"] = megatron_lm_plugin.encoder_seq_length + megatron_lm_plugin.megatron_lm_default_args["decoder_seq_length"] = megatron_lm_plugin.decoder_seq_length + megatron_lm_plugin.megatron_lm_default_args["model_type_name"] = model_type_name + megatron_lm_plugin.megatron_lm_default_args["num_layers"] = num_layers + megatron_lm_plugin.megatron_lm_default_args["hidden_size"] = hidden_size + megatron_lm_plugin.megatron_lm_default_args["num_attention_heads"] = num_attention_heads + megatron_lm_plugin.megatron_lm_default_args["max_position_embeddings"] = max_position_embeddings + megatron_lm_plugin.megatron_lm_default_args["pretraining_flag"] = pretraining_flag + megatron_lm_plugin.megatron_lm_default_args["orig_vocab_size"] = orig_vocab_size + megatron_lm_plugin.megatron_lm_default_args["model_return_dict"] = model.config.return_dict + + +@add_model_config_to_megatron_parser("llama") +def parse_llama_config(megatron_lm_plugin, model, batch_data): + model_type_name = "gpt" + num_layers = model.config.num_hidden_layers + pretraining_flag = True + hidden_size = model.config.hidden_size + num_attention_heads = model.config.num_attention_heads + orig_vocab_size = model.config.vocab_size + + max_position_embeddings = model.config.max_position_embeddings + seq_length = getattr(model.config, "max_sequence_length", None) + if megatron_lm_plugin.seq_length is None: + if seq_length is not None: + megatron_lm_plugin.seq_length = seq_length + elif megatron_lm_plugin.decoder_seq_length is not None: + megatron_lm_plugin.seq_length = megatron_lm_plugin.decoder_seq_length + elif batch_data is not None: + megatron_lm_plugin.seq_length = batch_data["input_ids"].shape[1] + else: + megatron_lm_plugin.seq_length = max_position_embeddings + + megatron_lm_plugin.megatron_lm_default_args["return_logits"] = megatron_lm_plugin.return_logits + megatron_lm_plugin.megatron_lm_default_args["tokenizer_type"] = "Llama2Tokenizer" + megatron_lm_plugin.megatron_lm_default_args["model_type_name"] = model_type_name + megatron_lm_plugin.megatron_lm_default_args["num_layers"] = num_layers + megatron_lm_plugin.megatron_lm_default_args["pretraining_flag"] = pretraining_flag + megatron_lm_plugin.megatron_lm_default_args["hidden_size"] = hidden_size + megatron_lm_plugin.megatron_lm_default_args["num_attention_heads"] = num_attention_heads + megatron_lm_plugin.megatron_lm_default_args["orig_vocab_size"] = orig_vocab_size + megatron_lm_plugin.megatron_lm_default_args["max_position_embeddings"] = max_position_embeddings + megatron_lm_plugin.megatron_lm_default_args["seq_length"] = megatron_lm_plugin.seq_length + megatron_lm_plugin.megatron_lm_default_args["model_return_dict"] = model.config.return_dict + + +@add_model_config_to_megatron_parser("glm4_moe") +def parse_glm4_moe_config(megatron_lm_plugin, model, batch_data): + model_type_name = "gpt" + num_layers = model.config.num_hidden_layers + pretraining_flag = False + hidden_size = model.config.hidden_size + num_attention_heads = model.config.num_attention_heads + orig_vocab_size = model.config.vocab_size + + max_position_embeddings = model.config.max_position_embeddings + seq_length = getattr(model.config, "max_sequence_length", None) + if megatron_lm_plugin.seq_length is None: + if seq_length is not None: + megatron_lm_plugin.seq_length = seq_length + elif megatron_lm_plugin.decoder_seq_length is not None: + megatron_lm_plugin.seq_length = megatron_lm_plugin.decoder_seq_length + elif batch_data is not None: + megatron_lm_plugin.seq_length = batch_data["input_ids"].shape[1] + else: + megatron_lm_plugin.seq_length = max_position_embeddings + + megatron_lm_plugin.megatron_lm_default_args["return_logits"] = megatron_lm_plugin.return_logits + megatron_lm_plugin.megatron_lm_default_args["tokenizer_type"] = "HuggingFaceTokenizer" + megatron_lm_plugin.megatron_lm_default_args["model_type_name"] = model_type_name + megatron_lm_plugin.megatron_lm_default_args["num_layers"] = num_layers + megatron_lm_plugin.megatron_lm_default_args["pretraining_flag"] = pretraining_flag + megatron_lm_plugin.megatron_lm_default_args["hidden_size"] = hidden_size + megatron_lm_plugin.megatron_lm_default_args["num_attention_heads"] = num_attention_heads + megatron_lm_plugin.megatron_lm_default_args["kv_channels"] = model.config.head_dim + megatron_lm_plugin.megatron_lm_default_args["orig_vocab_size"] = orig_vocab_size + megatron_lm_plugin.megatron_lm_default_args["max_position_embeddings"] = max_position_embeddings + megatron_lm_plugin.megatron_lm_default_args["seq_length"] = megatron_lm_plugin.seq_length + megatron_lm_plugin.megatron_lm_default_args["model_return_dict"] = model.config.return_dict + megatron_lm_plugin.megatron_lm_default_args["position_embedding_type"] = "rope" + megatron_lm_plugin.megatron_lm_default_args["original_model_type"] = model.config.model_type + megatron_lm_plugin.megatron_lm_default_args["qk_layernorm"] = ( + model.config.use_qk_norm + ) # this is true for glm4.5 but False for glm4.5-air. + megatron_lm_plugin.megatron_lm_default_args["add_bias_linear"] = False + megatron_lm_plugin.megatron_lm_default_args["group_query_attention"] = True + megatron_lm_plugin.megatron_lm_default_args["num_query_groups"] = model.config.num_key_value_heads + megatron_lm_plugin.megatron_lm_default_args["ffn_hidden_size"] = model.config.intermediate_size + megatron_lm_plugin.megatron_lm_default_args["add_qkv_bias"] = True + megatron_lm_plugin.megatron_lm_default_args["normalization"] = "RMSNorm" + megatron_lm_plugin.megatron_lm_default_args["rotary-percent"] = 0.5 + megatron_lm_plugin.megatron_lm_default_args["swiglu"] = True + megatron_lm_plugin.megatron_lm_default_args["moe_ffn_hidden_size"] = model.config.moe_intermediate_size + megatron_lm_plugin.megatron_lm_default_args["moe_shared_expert_intermediate_size"] = ( + model.config.moe_intermediate_size + ) + megatron_lm_plugin.megatron_lm_default_args["moe_router_pre_softmax"] = True + megatron_lm_plugin.megatron_lm_default_args["moe_router_score_function"] = "sigmoid" + megatron_lm_plugin.megatron_lm_default_args["moe_router_enable_expert_bias"] = True + megatron_lm_plugin.megatron_lm_default_args["moe_router_bias_update_rate"] = 0 + megatron_lm_plugin.megatron_lm_default_args["moe_router_load_balancing_type"] = "seq_aux_loss" + megatron_lm_plugin.megatron_lm_default_args["moe_token_dispatcher_type"] = "alltoall" + megatron_lm_plugin.megatron_lm_default_args["moe_router_topk"] = model.config.num_experts_per_tok + megatron_lm_plugin.megatron_lm_default_args["moe_router_topk_scaling_factor"] = model.config.routed_scaling_factor + megatron_lm_plugin.megatron_lm_default_args["moe_layer_freq"] = [0] * model.config.first_k_dense_replace + [1] * ( + model.config.num_hidden_layers - model.config.first_k_dense_replace + ) + megatron_lm_plugin.megatron_lm_default_args["num_experts"] = model.config.n_routed_experts + megatron_lm_plugin.megatron_lm_default_args["moe_grouped_gemm"] = True + megatron_lm_plugin.megatron_lm_default_args["moe_router_dtype"] = "fp32" + megatron_lm_plugin.megatron_lm_default_args["moe_permute_fusion"] = True + megatron_lm_plugin.megatron_lm_default_args["moe_aux_loss_coeff"] = 0 + megatron_lm_plugin.megatron_lm_default_args["rotary_base"] = model.config.rope_theta + megatron_lm_plugin.megatron_lm_default_args["rope_type"] = "rope" + megatron_lm_plugin.megatron_lm_default_args["rotary_percent"] = model.config.partial_rotary_factor + megatron_lm_plugin.megatron_lm_default_args["norm_epsilon"] = 1e-3 + megatron_lm_plugin.megatron_lm_default_args["use_flash_attn"] = True + megatron_lm_plugin.megatron_lm_default_args["eos_token_id"] = model.config.eos_token_id + if getattr(model.config, "fp8_param", False): + megatron_lm_plugin.megatron_lm_default_args["fp8"] = model.config.fp8 + megatron_lm_plugin.megatron_lm_default_args["fp8_param"] = model.config.fp8_param + megatron_lm_plugin.megatron_lm_default_args["fp8_param_gather"] = model.config.fp8_param_gather + megatron_lm_plugin.megatron_lm_default_args["fp8_recipe"] = model.config.fp8_recipe + megatron_lm_plugin.megatron_lm_default_args["bf16"] = model.config.bf16 + megatron_lm_plugin.megatron_lm_default_args[ + "untie_embeddings_and_output_weights" + ] = not model.config.tie_word_embeddings + logger.info(f"Parsed GLM4 MoE config: {megatron_lm_plugin.megatron_lm_default_args}") + + +@dataclass +class BnbQuantizationConfig: + """ + A plugin to enable BitsAndBytes 4bit and 8bit quantization + + Args: + load_in_8bit (`bool`, defaults to `False`): + Enable 8bit quantization. + llm_int8_threshold (`float`, defaults to `6.0`): + Value of the outliner threshold. Only relevant when `load_in_8bit=True`. + load_in_4bit (`bool`, defaults to `False`): + Enable 4bit quantization. + bnb_4bit_quant_type (`str`, defaults to `fp4`): + Set the quantization data type in the `bnb.nn.Linear4Bit` layers. Options are {'fp4','np4'}. + bnb_4bit_use_double_quant (`bool`, defaults to `False`): + Enable nested quantization where the quantization constants from the first quantization are quantized + again. + bnb_4bit_compute_dtype (`bool`, defaults to `fp16`): + This sets the computational type which might be different than the input time. For example, inputs might be + fp32, but computation can be set to bf16 for speedups. Options are {'fp32','fp16','bf16'}. + torch_dtype (`torch.dtype`, defaults to `None`): + This sets the dtype of the remaining non quantized layers. `bitsandbytes` library suggests to set the value + to `torch.float16` for 8 bit model and use the same dtype as the compute dtype for 4 bit model. + skip_modules (`List[str]`, defaults to `None`): + An explicit list of the modules that we don't quantize. The dtype of these modules will be `torch_dtype`. + keep_in_fp32_modules (`List`, defaults to `None`): + An explicit list of the modules that we don't quantize. We keep them in `torch.float32`. + """ + + load_in_8bit: bool = field(default=False, metadata={"help": "enable 8bit quantization."}) + + llm_int8_threshold: float = field( + default=6.0, + metadata={"help": "value of the outliner threshold. only relevant when load_in_8bit=True"}, + ) + + load_in_4bit: bool = field(default=False, metadata={"help": "enable 4bit quantization."}) + + bnb_4bit_quant_type: str = field( + default="fp4", + metadata={ + "help": "set the quantization data type in the `bnb.nn.Linear4Bit` layers. Options are {'fp4','nf4'}." + }, + ) + + bnb_4bit_use_double_quant: bool = field( + default=False, + metadata={ + "help": "enable nested quantization where the quantization constants from the first quantization are quantized again." + }, + ) + + bnb_4bit_compute_dtype: str = field( + default="fp16", + metadata={ + "help": "This sets the computational type which might be different than the input time. For example, inputs might be " + "fp32, but computation can be set to bf16 for speedups. Options are {'fp32','fp16','bf16'}." + }, + ) + + torch_dtype: torch.dtype = field( + default=None, + metadata={ + "help": "this sets the dtype of the remaining non quantized layers. `bitsandbytes` library suggests to set the value" + "to `torch.float16` for 8 bit model and use the same dtype as the compute dtype for 4 bit model " + }, + ) + + skip_modules: list[str] = field( + default=None, + metadata={ + "help": "an explicit list of the modules that we don't quantize. The dtype of these modules will be `torch_dtype`." + }, + ) + + keep_in_fp32_modules: list[str] = field( + default=None, + metadata={"help": "an explicit list of the modules that we don't quantize. We keep them in `torch.float32`."}, + ) + + def __post_init__(self): + """ + Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. + """ + if not isinstance(self.load_in_8bit, bool): + raise ValueError("load_in_8bit must be a boolean") + + if not isinstance(self.load_in_4bit, bool): + raise ValueError("load_in_4bit must be a boolean") + + if self.load_in_4bit and self.load_in_8bit: + raise ValueError("load_in_4bit and load_in_8bit can't be both True") + + if not self.load_in_4bit and not self.load_in_8bit: + raise ValueError("load_in_4bit and load_in_8bit can't be both False") + + if not isinstance(self.llm_int8_threshold, (int, float)): + raise ValueError("llm_int8_threshold must be a float or an int") + + if not isinstance(self.bnb_4bit_quant_type, str): + raise ValueError("bnb_4bit_quant_type must be a string") + elif self.bnb_4bit_quant_type not in ["fp4", "nf4"]: + raise ValueError(f"bnb_4bit_quant_type must be in ['fp4','nf4'] but found {self.bnb_4bit_quant_type}") + + if not isinstance(self.bnb_4bit_use_double_quant, bool): + raise ValueError("bnb_4bit_use_double_quant must be a boolean") + + if isinstance(self.bnb_4bit_compute_dtype, str): + if self.bnb_4bit_compute_dtype == "fp32": + self.bnb_4bit_compute_dtype = torch.float32 + elif self.bnb_4bit_compute_dtype == "fp16": + self.bnb_4bit_compute_dtype = torch.float16 + elif self.bnb_4bit_compute_dtype == "bf16": + self.bnb_4bit_compute_dtype = torch.bfloat16 + else: + raise ValueError( + f"bnb_4bit_compute_dtype must be in ['fp32','fp16','bf16'] but found {self.bnb_4bit_compute_dtype}" + ) + elif not isinstance(self.bnb_4bit_compute_dtype, torch.dtype): + raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype") + + if self.skip_modules is not None and not isinstance(self.skip_modules, list): + raise ValueError("skip_modules must be a list of strings") + + if self.keep_in_fp32_modules is not None and not isinstance(self.keep_in_fp32_modules, list): + raise ValueError("keep_in_fp_32_modules must be a list of strings") + + if self.load_in_4bit: + self.target_dtype = CustomDtype.INT4 + + if self.load_in_8bit: + self.target_dtype = torch.int8 + + if self.load_in_4bit and self.llm_int8_threshold != 6.0: + warnings.warn("llm_int8_threshold can only be used for model loaded in 8bit") + + if isinstance(self.torch_dtype, str): + if self.torch_dtype == "fp32": + self.torch_dtype = torch.float32 + elif self.torch_dtype == "fp16": + self.torch_dtype = torch.float16 + elif self.torch_dtype == "bf16": + self.torch_dtype = torch.bfloat16 + else: + raise ValueError(f"torch_dtype must be in ['fp32','fp16','bf16'] but found {self.torch_dtype}") + if self.load_in_8bit and self.torch_dtype is None: + self.torch_dtype = torch.float16 + + if self.load_in_4bit and self.torch_dtype is None: + self.torch_dtype = self.bnb_4bit_compute_dtype + + if not isinstance(self.torch_dtype, torch.dtype): + raise ValueError("torch_dtype must be a torch.dtype") + + +def get_module_class_from_name(module, name): + """ + Gets a class from a module by its name. + + Args: + module (`torch.nn.Module`): The module to get the class from. + name (`str`): The name of the class. + """ + modules_children = list(module.children()) + if module.__class__.__name__ == name: + return module.__class__ + elif len(modules_children) == 0: + return + else: + for child_module in modules_children: + module_class = get_module_class_from_name(child_module, name) + if module_class is not None: + return module_class diff --git a/accelerate/utils/deepspeed.py b/accelerate/utils/deepspeed.py new file mode 100644 index 0000000000000000000000000000000000000000..22db891c63d9bd48691acd87a15a206c270017a9 --- /dev/null +++ b/accelerate/utils/deepspeed.py @@ -0,0 +1,385 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import json +import os +from copy import deepcopy + +from torch import optim + +from ..optimizer import AcceleratedOptimizer +from ..scheduler import AcceleratedScheduler +from .dataclasses import DistributedType +from .imports import is_bnb_available +from .versions import compare_versions + + +def map_pytorch_optim_to_deepspeed(optimizer): + """ + Args: + optimizer: torch.optim.Optimizer + + Returns the DeepSeedCPUOptimizer (deepspeed.ops) version of the optimizer. + """ + + defaults = {k: v for k, v in optimizer.defaults.items() if k in ["lr", "weight_decay"]} + + # Select the DeepSpeedCPUOptimizer based on the original optimizer class. + # DeepSpeedCPUAdam is the default + from deepspeed.ops.adam import DeepSpeedCPUAdam + + optimizer_class = DeepSpeedCPUAdam + + # For DeepSpeedCPUAdam (adamw_mode) + if compare_versions("deepspeed", ">=", "0.3.1"): + defaults["adamw_mode"] = False + is_adaw = isinstance(optimizer, optim.AdamW) + + if is_bnb_available() and not is_adaw: + import bitsandbytes.optim as bnb_opt + + if isinstance(optimizer, (bnb_opt.AdamW, bnb_opt.AdamW32bit)): + try: + is_adaw = optimizer.optim_bits == 32 + except AttributeError: + is_adaw = optimizer.args.optim_bits == 32 + else: + is_adaw = False + + if is_adaw: + defaults["adamw_mode"] = True + + # For DeepSpeedCPUAdagrad + if compare_versions("deepspeed", ">=", "0.5.5"): + # Check if the optimizer is PyTorch's Adagrad. + is_ada = isinstance(optimizer, optim.Adagrad) + # If not, and bitsandbytes is available, + # # check if the optimizer is the 32-bit bitsandbytes Adagrad. + if is_bnb_available() and not is_ada: + import bitsandbytes.optim as bnb_opt + + if isinstance(optimizer, (bnb_opt.Adagrad, bnb_opt.Adagrad32bit)): + try: + is_ada = optimizer.optim_bits == 32 + except AttributeError: + is_ada = optimizer.args.optim_bits == 32 + if is_ada: + from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad + + optimizer_class = DeepSpeedCPUAdagrad + + # For DeepSpeedCPULion + if is_bnb_available(min_version="0.38.0") and compare_versions("deepspeed", ">=", "0.11.0"): + from bitsandbytes.optim import Lion, Lion32bit + + if isinstance(optimizer, (Lion, Lion32bit)): + try: + is_bnb_32bits = optimizer.optim_bits == 32 + except AttributeError: + is_bnb_32bits = optimizer.args.optim_bits == 32 + if is_bnb_32bits: + from deepspeed.ops.lion import DeepSpeedCPULion + + optimizer_class = DeepSpeedCPULion + + return optimizer_class(optimizer.param_groups, **defaults) + + +def get_active_deepspeed_plugin(state): + """ + Returns the currently active DeepSpeedPlugin. + + Raises: + ValueError: If DeepSpeed was not enabled and this function is called. + """ + if state.distributed_type != DistributedType.DEEPSPEED: + raise ValueError( + "Couldn't retrieve the active `DeepSpeedPlugin` as none were enabled. " + "Please make sure that either `Accelerator` is configured for `deepspeed` " + "or make sure that the desired `DeepSpeedPlugin` has been enabled (`AcceleratorState().select_deepspeed_plugin(name)`) " + "before calling this function." + ) + if not isinstance(state.deepspeed_plugins, dict): + return state.deepspeed_plugins + return next(plugin for plugin in state.deepspeed_plugins.values() if plugin.selected) + + +class HfDeepSpeedConfig: + """ + This object contains a DeepSpeed configuration dictionary and can be quickly queried for things like zero stage. + + A `weakref` of this object is stored in the module's globals to be able to access the config from areas where + things like the Trainer object is not available (e.g. `from_pretrained` and `_get_resized_embeddings`). Therefore + it's important that this object remains alive while the program is still running. + + [`Trainer`] uses the `HfTrainerDeepSpeedConfig` subclass instead. That subclass has logic to sync the configuration + with values of [`TrainingArguments`] by replacing special placeholder values: `"auto"`. Without this special logic + the DeepSpeed configuration is not modified in any way. + + Args: + config_file_or_dict (`Union[str, Dict]`): path to DeepSpeed config file or dict. + + """ + + def __init__(self, config_file_or_dict): + if isinstance(config_file_or_dict, dict): + # Don't modify user's data should they want to reuse it (e.g. in tests), because once we + # modified it, it will not be accepted here again, since `auto` values would have been overridden + config = deepcopy(config_file_or_dict) + elif os.path.exists(config_file_or_dict): + with open(config_file_or_dict, encoding="utf-8") as f: + config = json.load(f) + else: + try: + try: + # First try parsing as JSON directly + config = json.loads(config_file_or_dict) + except json.JSONDecodeError: + # If that fails, try base64 decoding + config_decoded = base64.urlsafe_b64decode(config_file_or_dict).decode("utf-8") + config = json.loads(config_decoded) + except (UnicodeDecodeError, AttributeError, ValueError): + raise ValueError( + f"Expected a string path to an existing deepspeed config, or a dictionary, or a base64 encoded string. Received: {config_file_or_dict}" + ) + + self.config = config + + self.set_stage_and_offload() + + def set_stage_and_offload(self): + # zero stage - this is done as early as possible, before model is created, to allow + # ``is_deepspeed_zero3_enabled`` query and getting to the early deepspeed config object + # during ``zero.Init()`` which needs to know the dtype, and some other hparams. + self._stage = self.get_value("zero_optimization.stage", -1) + + # offload + self._offload = False + if self.is_zero2() or self.is_zero3(): + offload_devices_valid = set(["cpu", "nvme"]) + offload_devices = set( + [ + self.get_value("zero_optimization.offload_optimizer.device"), + self.get_value("zero_optimization.offload_param.device"), + ] + ) + if len(offload_devices & offload_devices_valid) > 0: + self._offload = True + + def find_config_node(self, ds_key_long): + config = self.config + + # find the config node of interest if it exists + nodes = ds_key_long.split(".") + ds_key = nodes.pop() + for node in nodes: + config = config.get(node) + if config is None: + return None, ds_key + + return config, ds_key + + def get_value(self, ds_key_long, default=None): + """ + Returns the set value or `default` if no value is set + """ + config, ds_key = self.find_config_node(ds_key_long) + if config is None: + return default + return config.get(ds_key, default) + + def del_config_sub_tree(self, ds_key_long, must_exist=False): + """ + Deletes a sub-section of the config file if it's found. + + Unless `must_exist` is `True` the section doesn't have to exist. + """ + config = self.config + + # find the config node of interest if it exists + nodes = ds_key_long.split(".") + for node in nodes: + parent_config = config + config = config.get(node) + if config is None: + if must_exist: + raise ValueError(f"Can't find {ds_key_long} entry in the config: {self.config}") + else: + return + + # if found remove it + if parent_config is not None: + parent_config.pop(node) + + def is_true(self, ds_key_long): + """ + Returns `True`/``False` only if the value is set, always `False` otherwise. So use this method to ask the very + specific question of whether the value is set to `True` (and it's not set to `False`` or isn't set). + + """ + value = self.get_value(ds_key_long) + return False if value is None else bool(value) + + def is_false(self, ds_key_long): + """ + Returns `True`/``False` only if the value is set, always `False` otherwise. So use this method to ask the very + specific question of whether the value is set to `False` (and it's not set to `True`` or isn't set). + """ + value = self.get_value(ds_key_long) + return False if value is None else not bool(value) + + def is_zero2(self): + return self._stage == 2 + + def is_zero3(self): + return self._stage == 3 + + def is_offload(self): + return self._offload + + +class DeepSpeedEngineWrapper: + """ + Internal wrapper for deepspeed.runtime.engine.DeepSpeedEngine. This is used to follow conventional training loop. + + Args: + engine (deepspeed.runtime.engine.DeepSpeedEngine): deepspeed engine to wrap + """ + + def __init__(self, engine): + self.engine = engine + + def backward(self, loss, sync_gradients=True, **kwargs): + # Set gradient accumulation boundary based on Accelerate's sync_gradients state + # This tells DeepSpeed whether this is the final micro-batch before gradient sync + self.engine.set_gradient_accumulation_boundary(is_boundary=sync_gradients) + + # runs backpropagation and handles mixed precision + self.engine.backward(loss, **kwargs) + + # Only perform step and related operations at gradient accumulation boundaries + if sync_gradients: + # Deepspeed's `engine.step` performs the following operations: + # - gradient accumulation check + # - gradient clipping + # - optimizer step + # - zero grad + # - checking overflow + # - lr_scheduler step (only if engine.lr_scheduler is not None) + self.engine.step() + # and this plugin overrides the above calls with no-ops when Accelerate runs under + # Deepspeed, but allows normal functionality for non-Deepspeed cases thus enabling a simple + # training loop that works transparently under many training regimes. + + def get_global_grad_norm(self): + """Get the global gradient norm from DeepSpeed engine.""" + grad_norm = self.engine.get_global_grad_norm() + # Convert to scalar if it's a tensor + if hasattr(grad_norm, "item"): + return grad_norm.item() + return grad_norm + + +class DeepSpeedOptimizerWrapper(AcceleratedOptimizer): + """ + Internal wrapper around a deepspeed optimizer. + + Args: + optimizer (`torch.optim.optimizer.Optimizer`): + The optimizer to wrap. + """ + + def __init__(self, optimizer): + super().__init__(optimizer, device_placement=False, scaler=None) + self.__has_overflow__ = hasattr(self.optimizer, "overflow") + + def zero_grad(self, set_to_none=None): + pass # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed + + def step(self): + pass # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed + + @property + def step_was_skipped(self): + """Whether or not the optimizer step was done, or skipped because of gradient overflow.""" + if self.__has_overflow__: + return self.optimizer.overflow + return False + + +class DeepSpeedSchedulerWrapper(AcceleratedScheduler): + """ + Internal wrapper around a deepspeed scheduler. + + Args: + scheduler (`torch.optim.lr_scheduler.LambdaLR`): + The scheduler to wrap. + optimizers (one or a list of `torch.optim.Optimizer`): + """ + + def __init__(self, scheduler, optimizers): + super().__init__(scheduler, optimizers) + + def step(self): + pass # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed + + +class DummyOptim: + """ + Dummy optimizer presents model parameters or param groups, this is primarily used to follow conventional training + loop when optimizer config is specified in the deepspeed config file. + + Args: + lr (float): + Learning rate. + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + weight_decay (float): + Weight decay. + **kwargs (additional keyword arguments, *optional*): + Other arguments. + """ + + def __init__(self, params, lr=0.001, weight_decay=0, **kwargs): + self.params = params + self.lr = lr + self.weight_decay = weight_decay + self.kwargs = kwargs + + +class DummyScheduler: + """ + Dummy scheduler presents model parameters or param groups, this is primarily used to follow conventional training + loop when scheduler config is specified in the deepspeed config file. + + Args: + optimizer (`torch.optim.optimizer.Optimizer`): + The optimizer to wrap. + total_num_steps (int, *optional*): + Total number of steps. + warmup_num_steps (int, *optional*): + Number of steps for warmup. + lr_scheduler_callable (callable, *optional*): + A callable function that creates an LR Scheduler. It accepts only one argument `optimizer`. + **kwargs (additional keyword arguments, *optional*): + Other arguments. + """ + + def __init__(self, optimizer, total_num_steps=None, warmup_num_steps=0, lr_scheduler_callable=None, **kwargs): + self.optimizer = optimizer + self.total_num_steps = total_num_steps + self.warmup_num_steps = warmup_num_steps + self.lr_scheduler_callable = lr_scheduler_callable + self.kwargs = kwargs diff --git a/accelerate/utils/environment.py b/accelerate/utils/environment.py new file mode 100644 index 0000000000000000000000000000000000000000..792ca6e5c832b24bd5ba2039164561a125ea725c --- /dev/null +++ b/accelerate/utils/environment.py @@ -0,0 +1,474 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import math +import os +import platform +import subprocess +import sys +from contextlib import contextmanager +from dataclasses import dataclass, field +from functools import lru_cache, wraps +from shutil import which +from typing import Optional, Union + +import torch +from packaging.version import parse + + +logger = logging.getLogger(__name__) + + +def convert_dict_to_env_variables(current_env: dict): + """ + Verifies that all keys and values in `current_env` do not contain illegal keys or values, and returns a list of + strings as the result. + + Example: + ```python + >>> from accelerate.utils.environment import verify_env + + >>> env = {"ACCELERATE_DEBUG_MODE": "1", "BAD_ENV_NAME": ">> valid_env_items = verify_env(env) + >>> print(valid_env_items) + ["ACCELERATE_DEBUG_MODE=1\n", "OTHER_ENV=2\n"] + ``` + """ + forbidden_chars = [";", "\n", "<", ">", " "] + valid_env_items = [] + for key, value in current_env.items(): + if all(char not in (key + value) for char in forbidden_chars) and len(key) >= 1 and len(value) >= 1: + valid_env_items.append(f"{key}={value}\n") + else: + logger.warning(f"WARNING: Skipping {key}={value} as it contains forbidden characters or missing values.") + return valid_env_items + + +def str_to_bool(value, to_bool: bool = False) -> Union[int, bool]: + """ + Converts a string representation of truth to `True` (1) or `False` (0). + + True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False value are `n`, `no`, `f`, `false`, `off`, and `0`; + """ + value = value.lower() + if value in ("y", "yes", "t", "true", "on", "1"): + return 1 if not to_bool else True + elif value in ("n", "no", "f", "false", "off", "0"): + return 0 if not to_bool else False + else: + raise ValueError(f"invalid truth value {value}") + + +def get_int_from_env(env_keys, default): + """Returns the first positive env value found in the `env_keys` list or the default.""" + for e in env_keys: + val = int(os.environ.get(e, -1)) + if val >= 0: + return val + return default + + +def parse_flag_from_env(key, default=False): + """Returns truthy value for `key` from the env if available else the default.""" + value = os.environ.get(key, str(default)) + return str_to_bool(value) == 1 # As its name indicates `str_to_bool` actually returns an int... + + +def parse_choice_from_env(key, default="no"): + value = os.environ.get(key, str(default)) + return value + + +def are_libraries_initialized(*library_names: str) -> list[str]: + """ + Checks if any of `library_names` are imported in the environment. Will return any names that are. + """ + return [lib_name for lib_name in library_names if lib_name in sys.modules.keys()] + + +def get_current_device_type() -> tuple[str, str]: + """ + Determines the current device type and distributed type without initializing any device. + + This is particularly important when using fork-based multiprocessing, as device initialization + before forking can cause errors. + + The device detection order follows the same priority as state.py:_prepare_backend(): + MLU -> SDAA -> MUSA -> NPU -> HPU -> CUDA -> XPU + + Returns: + tuple[str, str]: A tuple of (device_type, distributed_type) + - device_type: The device string (e.g., "cuda", "npu", "xpu") + - distributed_type: The distributed type string (e.g., "MULTI_GPU", "MULTI_NPU") + + Example: + ```python + >>> device_type, distributed_type = get_current_device_type() + >>> print(device_type) # "cuda" + >>> print(distributed_type) # "MULTI_GPU" + ``` + """ + from .imports import ( + is_hpu_available, + is_mlu_available, + is_musa_available, + is_neuron_available, + is_npu_available, + is_sdaa_available, + is_xpu_available, + ) + + if is_mlu_available(): + return "mlu", "MULTI_MLU" + elif is_sdaa_available(): + return "sdaa", "MULTI_SDAA" + elif is_musa_available(): + return "musa", "MULTI_MUSA" + elif is_npu_available(): + return "npu", "MULTI_NPU" + elif is_hpu_available(): + return "hpu", "MULTI_HPU" + elif torch.cuda.is_available(): + return "cuda", "MULTI_GPU" + elif is_xpu_available(): + return "xpu", "MULTI_XPU" + elif is_neuron_available(): + return "neuron", "MULTI_NEURON" + else: + # Default to CUDA even if not available (for CPU-only scenarios where CUDA code paths are still used) + return "cuda", "MULTI_GPU" + + +def _nvidia_smi(): + """ + Returns the right nvidia-smi command based on the system. + """ + if platform.system() == "Windows": + # If platform is Windows and nvidia-smi can't be found in path + # try from systemd drive with default installation path + command = which("nvidia-smi") + if command is None: + command = f"{os.environ['systemdrive']}\\Program Files\\NVIDIA Corporation\\NVSMI\\nvidia-smi.exe" + else: + command = "nvidia-smi" + return command + + +def get_gpu_info(): + """ + Gets GPU count and names using `nvidia-smi` instead of torch to not initialize CUDA. + + Largely based on the `gputil` library. + """ + # Returns as list of `n` GPUs and their names + output = subprocess.check_output( + [_nvidia_smi(), "--query-gpu=count,name", "--format=csv,noheader"], universal_newlines=True + ) + output = output.strip() + gpus = output.split(os.linesep) + # Get names from output + gpu_count = len(gpus) + gpu_names = [gpu.split(",")[1].strip() for gpu in gpus] + return gpu_names, gpu_count + + +def get_driver_version(): + """ + Returns the driver version + + In the case of multiple GPUs, will return the first. + """ + output = subprocess.check_output( + [_nvidia_smi(), "--query-gpu=driver_version", "--format=csv,noheader"], universal_newlines=True + ) + output = output.strip() + return output.split(os.linesep)[0] + + +def check_cuda_p2p_ib_support(): + """ + Checks if the devices being used have issues with P2P and IB communications, namely any consumer GPU hardware after + the 3090. + + Notably uses `nvidia-smi` instead of torch to not initialize CUDA. + """ + try: + device_names, device_count = get_gpu_info() + # As new consumer GPUs get released, add them to `unsupported_devices`` + unsupported_devices = {"RTX 40"} + if device_count > 1: + if any( + unsupported_device in device_name + for device_name in device_names + for unsupported_device in unsupported_devices + ): + # Check if they have the right driver version + acceptable_driver_version = "550.40.07" + current_driver_version = get_driver_version() + if parse(current_driver_version) < parse(acceptable_driver_version): + return False + return True + except Exception: + pass + return True + + +@lru_cache +def check_cuda_fp8_capability(): + """ + Checks if the current GPU available supports FP8. + + Notably might initialize `torch.cuda` to check. + """ + + try: + # try to get the compute capability from nvidia-smi + output = subprocess.check_output( + [_nvidia_smi(), "--query-gpu=compute_capability", "--format=csv,noheader"], universal_newlines=True + ) + output = output.strip() + # we take the first GPU's compute capability + compute_capability = tuple(map(int, output.split(os.linesep)[0].split("."))) + except Exception: + compute_capability = torch.cuda.get_device_capability() + + return compute_capability >= (8, 9) + + +@dataclass +class CPUInformation: + """ + Stores information about the CPU in a distributed environment. It contains the following attributes: + - rank: The rank of the current process. + - world_size: The total number of processes in the world. + - local_rank: The rank of the current process on the local node. + - local_world_size: The total number of processes on the local node. + """ + + rank: int = field(default=0, metadata={"help": "The rank of the current process."}) + world_size: int = field(default=1, metadata={"help": "The total number of processes in the world."}) + local_rank: int = field(default=0, metadata={"help": "The rank of the current process on the local node."}) + local_world_size: int = field(default=1, metadata={"help": "The total number of processes on the local node."}) + + +def get_cpu_distributed_information() -> CPUInformation: + """ + Returns various information about the environment in relation to CPU distributed training as a `CPUInformation` + dataclass. + """ + information = {} + information["rank"] = get_int_from_env(["RANK", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "MV2_COMM_WORLD_RANK"], 0) + information["world_size"] = get_int_from_env( + ["WORLD_SIZE", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE"], 1 + ) + information["local_rank"] = get_int_from_env( + ["LOCAL_RANK", "MPI_LOCALRANKID", "OMPI_COMM_WORLD_LOCAL_RANK", "MV2_COMM_WORLD_LOCAL_RANK"], 0 + ) + information["local_world_size"] = get_int_from_env( + ["LOCAL_WORLD_SIZE", "MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], + 1, + ) + return CPUInformation(**information) + + +def override_numa_affinity(local_process_index: int, verbose: Optional[bool] = None) -> None: + """ + Overrides whatever NUMA affinity is set for the current process. This is very taxing and requires recalculating the + affinity to set, ideally you should use `utils.environment.set_numa_affinity` instead. + + Args: + local_process_index (int): + The index of the current process on the current server. + verbose (bool, *optional*): + Whether to log out the assignment of each CPU. If `ACCELERATE_DEBUG_MODE` is enabled, will default to True. + """ + if verbose is None: + verbose = parse_flag_from_env("ACCELERATE_DEBUG_MODE", False) + if torch.cuda.is_available(): + from accelerate.utils import is_pynvml_available + + if not is_pynvml_available(): + raise ImportError( + "To set CPU affinity on CUDA GPUs the `nvidia-ml-py` package must be available. (`pip install nvidia-ml-py`)" + ) + import pynvml as nvml + + # The below code is based on https://github.com/NVIDIA/DeepLearningExamples/blob/master/TensorFlow2/LanguageModeling/BERT/gpu_affinity.py + nvml.nvmlInit() + num_elements = math.ceil(os.cpu_count() / 64) + handle = nvml.nvmlDeviceGetHandleByIndex(local_process_index) + affinity_string = "" + for j in nvml.nvmlDeviceGetCpuAffinity(handle, num_elements): + # assume nvml returns list of 64 bit ints + affinity_string = f"{j:064b}{affinity_string}" + affinity_list = [int(x) for x in affinity_string] + affinity_list.reverse() # so core 0 is the 0th element + affinity_to_set = [i for i, e in enumerate(affinity_list) if e != 0] + os.sched_setaffinity(0, affinity_to_set) + if verbose: + cpu_cores = os.sched_getaffinity(0) + logger.info(f"Assigning {len(cpu_cores)} cpu cores to process {local_process_index}: {cpu_cores}") + + +@lru_cache +def set_numa_affinity(local_process_index: int, verbose: Optional[bool] = None) -> None: + """ + Assigns the current process to a specific NUMA node. Ideally most efficient when having at least 2 cpus per node. + + This result is cached between calls. If you want to override it, please use + `accelerate.utils.environment.override_numa_afifnity`. + + Args: + local_process_index (int): + The index of the current process on the current server. + verbose (bool, *optional*): + Whether to print the new cpu cores assignment for each process. If `ACCELERATE_DEBUG_MODE` is enabled, will + default to True. + """ + override_numa_affinity(local_process_index=local_process_index, verbose=verbose) + + +@contextmanager +def clear_environment(): + """ + A context manager that will temporarily clear environment variables. + + When this context exits, the previous environment variables will be back. + + Example: + + ```python + >>> import os + >>> from accelerate.utils import clear_environment + + >>> os.environ["FOO"] = "bar" + >>> with clear_environment(): + ... print(os.environ) + ... os.environ["FOO"] = "new_bar" + ... print(os.environ["FOO"]) + {} + new_bar + + >>> print(os.environ["FOO"]) + bar + ``` + """ + _old_os_environ = os.environ.copy() + os.environ.clear() + + try: + yield + finally: + os.environ.clear() # clear any added keys, + os.environ.update(_old_os_environ) # then restore previous environment + + +@contextmanager +def patch_environment(**kwargs): + """ + A context manager that will add each keyword argument passed to `os.environ` and remove them when exiting. + + Will convert the values in `kwargs` to strings and upper-case all the keys. + + Example: + + ```python + >>> import os + >>> from accelerate.utils import patch_environment + + >>> with patch_environment(FOO="bar"): + ... print(os.environ["FOO"]) # prints "bar" + >>> print(os.environ["FOO"]) # raises KeyError + ``` + """ + existing_vars = {} + for key, value in kwargs.items(): + key = key.upper() + if key in os.environ: + existing_vars[key] = os.environ[key] + os.environ[key] = str(value) + + try: + yield + finally: + for key in kwargs: + key = key.upper() + if key in existing_vars: + # restore previous value + os.environ[key] = existing_vars[key] + else: + os.environ.pop(key, None) + + +def purge_accelerate_environment(func_or_cls): + """Decorator to clean up accelerate environment variables set by the decorated class or function. + + In some circumstances, calling certain classes or functions can result in accelerate env vars being set and not + being cleaned up afterwards. As an example, when calling: + + TrainingArguments(fp16=True, ...) + + The following env var will be set: + + ACCELERATE_MIXED_PRECISION=fp16 + + This can affect subsequent code, since the env var takes precedence over TrainingArguments(fp16=False). This is + especially relevant for unit testing, where we want to avoid the individual tests to have side effects on one + another. Decorate the unit test function or whole class with this decorator to ensure that after each test, the env + vars are cleaned up. This works for both unittest.TestCase and normal classes (pytest); it also works when + decorating the parent class. + + """ + prefix = "ACCELERATE_" + + @contextmanager + def env_var_context(): + # Store existing accelerate env vars + existing_vars = {k: v for k, v in os.environ.items() if k.startswith(prefix)} + try: + yield + finally: + # Restore original env vars or remove new ones + for key in [k for k in os.environ if k.startswith(prefix)]: + if key in existing_vars: + os.environ[key] = existing_vars[key] + else: + os.environ.pop(key, None) + + def wrap_function(func): + @wraps(func) + def wrapper(*args, **kwargs): + with env_var_context(): + return func(*args, **kwargs) + + wrapper._accelerate_is_purged_environment_wrapped = True + return wrapper + + if not isinstance(func_or_cls, type): + return wrap_function(func_or_cls) + + # Handle classes by wrapping test methods + def wrap_test_methods(test_class_instance): + for name in dir(test_class_instance): + if name.startswith("test"): + method = getattr(test_class_instance, name) + if callable(method) and not hasattr(method, "_accelerate_is_purged_environment_wrapped"): + setattr(test_class_instance, name, wrap_function(method)) + return test_class_instance + + # Handle inheritance + wrap_test_methods(func_or_cls) + func_or_cls.__init_subclass__ = classmethod(lambda cls, **kw: wrap_test_methods(cls)) + return func_or_cls diff --git a/accelerate/utils/fsdp_utils.py b/accelerate/utils/fsdp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8663f2c5e9e2e3274f7e6c806c999f35bf171900 --- /dev/null +++ b/accelerate/utils/fsdp_utils.py @@ -0,0 +1,857 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import functools +import os +import re +import shutil +import warnings +from collections import defaultdict +from collections.abc import Iterable +from contextlib import nullcontext +from pathlib import Path +from typing import Callable, Union + +import torch + +from ..logging import get_logger +from .constants import FSDP_MODEL_NAME, OPTIMIZER_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_NAME +from .dataclasses import get_module_class_from_name +from .modeling import get_non_persistent_buffers, is_peft_model +from .other import get_module_children_bottom_up, is_compiled_module, save +from .versions import is_torch_version + + +logger = get_logger(__name__) + + +def enable_fsdp_ram_efficient_loading(): + """ + Enables RAM efficient loading of Hugging Face models for FSDP in the environment. + """ + # Sets values for `transformers.modeling_utils.is_fsdp_enabled` + if "ACCELERATE_USE_FSDP" not in os.environ: + os.environ["ACCELERATE_USE_FSDP"] = "True" + os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "True" + + +def disable_fsdp_ram_efficient_loading(): + """ + Disables RAM efficient loading of Hugging Face models for FSDP in the environment. + """ + os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "False" + + +def _get_model_state_dict(model, adapter_only=False, sd_options=None): + if adapter_only and is_peft_model(model): + from peft import get_peft_model_state_dict + + return get_peft_model_state_dict(model, adapter_name=model.active_adapter) + + # Invariant: `sd_options` is not None only for FSDP2 + if sd_options is not None: + from torch.distributed.checkpoint.state_dict import get_model_state_dict + + return get_model_state_dict(model, options=sd_options) + else: + return model.state_dict() + + +def _set_model_state_dict(model, state_dict, adapter_only=False, sd_options=None): + if adapter_only and is_peft_model(model): + from peft import set_peft_model_state_dict + + return set_peft_model_state_dict(model, state_dict, adapter_name=model.active_adapter) + + # Invariant: `sd_options` is not None only for FSDP2 + if sd_options is not None: + from torch.distributed.checkpoint.state_dict import set_model_state_dict + + return set_model_state_dict(model, state_dict, options=sd_options) + else: + return model.load_state_dict(state_dict) + + +def _prepare_sd_options(fsdp_plugin): + sd_options = None + + # we use this only for FSDP2, as it requires torch >= 2.6.0 and this api requires torch >= 2.2.0 + if fsdp_plugin.fsdp_version == 2: + from torch.distributed.checkpoint.state_dict import StateDictOptions + from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType + + sd_options = StateDictOptions( + full_state_dict=fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT, + cpu_offload=getattr(fsdp_plugin.state_dict_config, "offload_to_cpu", False), + broadcast_from_rank0=getattr(fsdp_plugin.state_dict_config, "rank0_only", False), + ) + + return sd_options + + +def save_fsdp_model(fsdp_plugin, accelerator, model, output_dir, model_index=0, adapter_only=False): + # Note: We import here to reduce import time from general modules, and isolate outside dependencies + import torch.distributed.checkpoint as dist_cp + from torch.distributed.checkpoint.default_planner import DefaultSavePlanner + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType + + os.makedirs(output_dir, exist_ok=True) + if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT: + # FSDP raises error when single GPU is used with `offload_to_cpu=True` for FULL_STATE_DICT + # so, only enable it when num_processes>1 + is_multi_process = accelerator.num_processes > 1 + fsdp_plugin.state_dict_config.offload_to_cpu = is_multi_process + fsdp_plugin.state_dict_config.rank0_only = is_multi_process + + ctx = ( + FSDP.state_dict_type( + model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config + ) + if fsdp_plugin.fsdp_version == 1 + else nullcontext() + ) + sd_options = _prepare_sd_options(fsdp_plugin) + + with ctx: + state_dict = _get_model_state_dict(model, adapter_only=adapter_only, sd_options=sd_options) + if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT: + weights_name = f"{FSDP_MODEL_NAME}.bin" if model_index == 0 else f"{FSDP_MODEL_NAME}_{model_index}.bin" + output_model_file = os.path.join(output_dir, weights_name) + if accelerator.process_index == 0: + logger.info(f"Saving model to {output_model_file}") + torch.save(state_dict, output_model_file) + logger.info(f"Model saved to {output_model_file}") + # Invariant: `LOCAL_STATE_DICT` is never possible with `FSDP2` + elif fsdp_plugin.state_dict_type == StateDictType.LOCAL_STATE_DICT: + weights_name = ( + f"{FSDP_MODEL_NAME}_rank{accelerator.process_index}.bin" + if model_index == 0 + else f"{FSDP_MODEL_NAME}_{model_index}_rank{accelerator.process_index}.bin" + ) + output_model_file = os.path.join(output_dir, weights_name) + logger.info(f"Saving model to {output_model_file}") + torch.save(state_dict, output_model_file) + logger.info(f"Model saved to {output_model_file}") + elif fsdp_plugin.state_dict_type == StateDictType.SHARDED_STATE_DICT: + ckpt_dir = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{model_index}") + os.makedirs(ckpt_dir, exist_ok=True) + logger.info(f"Saving model to {ckpt_dir}") + state_dict = {"model": state_dict} + + dist_cp.save( + state_dict=state_dict, + storage_writer=dist_cp.FileSystemWriter(ckpt_dir), + planner=DefaultSavePlanner(), + ) + logger.info(f"Model saved to {ckpt_dir}") + + +def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_index=0, adapter_only=False): + # Note: We import here to reduce import time from general modules, and isolate outside dependencies + import torch.distributed.checkpoint as dist_cp + from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType + + accelerator.wait_for_everyone() + if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT: + # FSDP raises error when single GPU is used with `offload_to_cpu=True` for FULL_STATE_DICT + # so, only enable it when num_processes>1 + is_multi_process = accelerator.num_processes > 1 + fsdp_plugin.state_dict_config.offload_to_cpu = is_multi_process + fsdp_plugin.state_dict_config.rank0_only = is_multi_process + + ctx = ( + FSDP.state_dict_type( + model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config + ) + if fsdp_plugin.fsdp_version == 1 + else nullcontext() + ) + sd_options = _prepare_sd_options(fsdp_plugin) + with ctx: + if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT: + if type(model) is not FSDP and accelerator.process_index != 0 and not accelerator.is_fsdp2: + if not fsdp_plugin.sync_module_states and fsdp_plugin.fsdp_version == 1: + raise ValueError( + "Set the `sync_module_states` flag to `True` so that model states are synced across processes when " + "initializing FSDP object" + ) + return + weights_name = f"{FSDP_MODEL_NAME}.bin" if model_index == 0 else f"{FSDP_MODEL_NAME}_{model_index}.bin" + input_model_file = os.path.join(input_dir, weights_name) + logger.info(f"Loading model from {input_model_file}") + # we want an empty state dict for FSDP2 as we use `broadcast_from_rank0` + load_model = not accelerator.is_fsdp2 or accelerator.is_main_process + if load_model: + state_dict = torch.load(input_model_file, weights_only=True) + else: + state_dict = {} + logger.info(f"Model loaded from {input_model_file}") + elif fsdp_plugin.state_dict_type == StateDictType.LOCAL_STATE_DICT: + weights_name = ( + f"{FSDP_MODEL_NAME}_rank{accelerator.process_index}.bin" + if model_index == 0 + else f"{FSDP_MODEL_NAME}_{model_index}_rank{accelerator.process_index}.bin" + ) + input_model_file = os.path.join(input_dir, weights_name) + logger.info(f"Loading model from {input_model_file}") + state_dict = torch.load(input_model_file, weights_only=True) + logger.info(f"Model loaded from {input_model_file}") + elif fsdp_plugin.state_dict_type == StateDictType.SHARDED_STATE_DICT: + ckpt_dir = ( + os.path.join(input_dir, f"{FSDP_MODEL_NAME}_{model_index}") + if f"{FSDP_MODEL_NAME}" not in input_dir + else input_dir + ) + logger.info(f"Loading model from {ckpt_dir}") + state_dict = {"model": _get_model_state_dict(model, adapter_only=adapter_only, sd_options=sd_options)} + dist_cp.load( + state_dict=state_dict, + storage_reader=dist_cp.FileSystemReader(ckpt_dir), + planner=DefaultLoadPlanner(), + ) + state_dict = state_dict["model"] + logger.info(f"Model loaded from {ckpt_dir}") + + load_result = _set_model_state_dict(model, state_dict, adapter_only=adapter_only, sd_options=sd_options) + return load_result + + +def save_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, output_dir, optimizer_index=0): + # Note: We import here to reduce import time from general modules, and isolate outside dependencies + import torch.distributed.checkpoint as dist_cp + from torch.distributed.checkpoint.default_planner import DefaultSavePlanner + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType + + os.makedirs(output_dir, exist_ok=True) + + ctx = ( + FSDP.state_dict_type( + model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config + ) + if fsdp_plugin.fsdp_version == 1 + else nullcontext() + ) + + sd_options = _prepare_sd_options(fsdp_plugin) + + with ctx: + if fsdp_plugin.fsdp_version == 2: + from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict + + optim_state = get_optimizer_state_dict(model, optimizer, options=sd_options) + else: + optim_state = FSDP.optim_state_dict(model, optimizer) + + if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT: + if accelerator.process_index == 0: + optim_state_name = ( + f"{OPTIMIZER_NAME}.bin" if optimizer_index == 0 else f"{OPTIMIZER_NAME}_{optimizer_index}.bin" + ) + output_optimizer_file = os.path.join(output_dir, optim_state_name) + logger.info(f"Saving Optimizer state to {output_optimizer_file}") + torch.save(optim_state, output_optimizer_file) + logger.info(f"Optimizer state saved in {output_optimizer_file}") + else: + ckpt_dir = os.path.join(output_dir, f"{OPTIMIZER_NAME}_{optimizer_index}") + os.makedirs(ckpt_dir, exist_ok=True) + logger.info(f"Saving Optimizer state to {ckpt_dir}") + dist_cp.save( + state_dict={"optimizer": optim_state}, + storage_writer=dist_cp.FileSystemWriter(ckpt_dir), + planner=DefaultSavePlanner(), + ) + logger.info(f"Optimizer state saved in {ckpt_dir}") + + +def load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, input_dir, optimizer_index=0, adapter_only=False): + # Note: We import here to reduce import time from general modules, and isolate outside dependencies + import torch.distributed.checkpoint as dist_cp + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType + + accelerator.wait_for_everyone() + ctx = ( + FSDP.state_dict_type( + model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config + ) + if fsdp_plugin.fsdp_version == 1 + else nullcontext() + ) + sd_options = _prepare_sd_options(fsdp_plugin) + with ctx: + if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT: + optim_state = None + if accelerator.process_index == 0 or not fsdp_plugin.optim_state_dict_config.rank0_only: + optimizer_name = ( + f"{OPTIMIZER_NAME}.bin" if optimizer_index == 0 else f"{OPTIMIZER_NAME}_{optimizer_index}.bin" + ) + input_optimizer_file = os.path.join(input_dir, optimizer_name) + logger.info(f"Loading Optimizer state from {input_optimizer_file}") + optim_state = torch.load(input_optimizer_file, weights_only=True) + logger.info(f"Optimizer state loaded from {input_optimizer_file}") + else: + ckpt_dir = ( + os.path.join(input_dir, f"{OPTIMIZER_NAME}_{optimizer_index}") + if f"{OPTIMIZER_NAME}" not in input_dir + else input_dir + ) + logger.info(f"Loading Optimizer from {ckpt_dir}") + if fsdp_plugin.fsdp_version == 2: + from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict + + optim_state = get_optimizer_state_dict(model, optimizer, options=sd_options) + else: + optim_state = FSDP.optim_state_dict(model, optimizer) + optim_state = {"optimizer": optim_state} + dist_cp.load( + optim_state, + checkpoint_id=ckpt_dir, + storage_reader=dist_cp.FileSystemReader(ckpt_dir), + ) + optim_state = optim_state["optimizer"] + logger.info(f"Optimizer loaded from {ckpt_dir}") + + if fsdp_plugin.fsdp_version == 1: + flattened_osd = FSDP.optim_state_dict_to_load(model=model, optim=optimizer, optim_state_dict=optim_state) + optimizer.load_state_dict(flattened_osd) + else: + from torch.distributed.checkpoint.state_dict import set_optimizer_state_dict + + set_optimizer_state_dict(model, optimizer, optim_state, options=sd_options) + + +def _distributed_checkpoint_to_merged_weights(checkpoint_dir: str, save_path: str, safe_serialization: bool = True): + """ + Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save` + + Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`. + """ + # Note: We import here to reduce import time from general modules, and isolate outside dependencies + import torch.distributed.checkpoint as dist_cp + import torch.distributed.checkpoint.format_utils as dist_cp_format_utils + + state_dict = {} + save_path = Path(save_path) + save_path.mkdir(exist_ok=True) + dist_cp_format_utils._load_state_dict( + state_dict, + storage_reader=dist_cp.FileSystemReader(checkpoint_dir), + planner=dist_cp_format_utils._EmptyStateDictLoadPlanner(), + no_dist=True, + ) + save_path = save_path / SAFE_WEIGHTS_NAME if safe_serialization else save_path / WEIGHTS_NAME + + # To handle if state is a dict like {model: {...}} + if len(state_dict.keys()) == 1: + state_dict = state_dict[list(state_dict)[0]] + save(state_dict, save_path, safe_serialization=safe_serialization) + return save_path + + +def merge_fsdp_weights( + checkpoint_dir: str, output_path: str, safe_serialization: bool = True, remove_checkpoint_dir: bool = False +): + """ + Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if + `SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if + `safe_serialization` else `pytorch_model.bin`. + + Note: this is a CPU-bound process. + + Args: + checkpoint_dir (`str`): + The directory containing the FSDP checkpoints (can be either the model or optimizer). + output_path (`str`): + The path to save the merged checkpoint. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the merged weights with safetensors (recommended). + remove_checkpoint_dir (`bool`, *optional*, defaults to `False`): + Whether to remove the checkpoint directory after merging. + """ + checkpoint_dir = Path(checkpoint_dir) + from accelerate.state import PartialState + + if not is_torch_version(">=", "2.3.0"): + raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`") + + # Verify that the checkpoint directory exists + if not checkpoint_dir.exists(): + model_path_exists = (checkpoint_dir / "pytorch_model_fsdp_0").exists() + optimizer_path_exists = (checkpoint_dir / "optimizer_0").exists() + err = f"Tried to load from {checkpoint_dir} but couldn't find a valid metadata file." + if model_path_exists and optimizer_path_exists: + err += " However, potential model and optimizer checkpoint directories exist." + err += f"Please pass in either {checkpoint_dir}/pytorch_model_fsdp_0 or {checkpoint_dir}/optimizer_0" + err += "instead." + elif model_path_exists: + err += " However, a potential model checkpoint directory exists." + err += f"Please try passing in {checkpoint_dir}/pytorch_model_fsdp_0 instead." + elif optimizer_path_exists: + err += " However, a potential optimizer checkpoint directory exists." + err += f"Please try passing in {checkpoint_dir}/optimizer_0 instead." + raise ValueError(err) + + # To setup `save` to work + state = PartialState() + if state.is_main_process: + logger.info(f"Merging FSDP weights from {checkpoint_dir}") + save_path = _distributed_checkpoint_to_merged_weights(checkpoint_dir, output_path, safe_serialization) + logger.info(f"Successfully merged FSDP weights and saved to {save_path}") + if remove_checkpoint_dir: + logger.info(f"Removing old checkpoint directory {checkpoint_dir}") + shutil.rmtree(checkpoint_dir) + state.wait_for_everyone() + + +def ensure_weights_retied(param_init_fn, model: torch.nn.Module, device: torch.device): + _tied_names = getattr(model, "_tied_weights_keys", None) + if not _tied_names: + # if no tied names just passthrough + return param_init_fn + + # get map of parameter instances to params. + # - needed for replacement later + _tied_params = {} + for name in _tied_names: + name = name.split(".") + name, param_name = ".".join(name[:-1]), name[-1] + mod = model.get_submodule(name) + param = getattr(mod, param_name) + + _tied_params[id(param)] = None # placeholder for the param first + + # build param_init_fn for the case with tied params + def param_init_fn_tied_param(module: torch.nn.Module): + # track which params to tie + # - usually only 1, but for completeness consider > 1 + params_to_tie = defaultdict(list) + for n, param in module.named_parameters(recurse=False): + if id(param) in _tied_params: + params_to_tie[id(param)].append(n) + + # call the param init fn, which potentially re-allocates the + # parameters + module = param_init_fn(module) + + # search the parameters again and tie them up again + for id_key, _param_names in params_to_tie.items(): + for param_name in _param_names: + param = _tied_params[id_key] + if param is None: + # everything will be tied to the first time the + # param is observed + _tied_params[id_key] = getattr(module, param_name) + else: + setattr(module, param_name, param) # tie + + return module + + return param_init_fn_tied_param + + +def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict, cpu_offload: bool = False): + """ + Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the + parameters from rank 0 to all other ranks. This function modifies the model in-place. + + Args: + accelerator (`Accelerator`): The accelerator instance + model (`torch.nn.Module`): + The model to load the state dict into, expected to be on meta device or a VRAM spike can occur + full_sd (`dict`): The full state dict to load, can only be on rank 0 + cpu_offload (`bool`, defaults to `False`): + If True, move sharded parameters to CPU after distribution. Required when FSDP CPU offloading is enabled. + """ + import torch.distributed as dist + from torch.distributed.tensor import DTensor, distribute_tensor + + # Model was previously copied to meta device + meta_sharded_sd = model.state_dict() + sharded_sd = {} + + # Rank 0 distributes the full state dict to other ranks + def _infer_parameter_dtype(model, param_name, empty_param): + try: + old_param = model.get_parameter_or_buffer(param_name) + except AttributeError: + # Need this for LORA, as there some params are not *parameters* of sorts + base_param_name, local_param_name = param_name.rsplit(".", 1) + submodule = model.get_submodule(base_param_name) + old_param = getattr(submodule, local_param_name) + + is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") + casting_dtype = None + is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn + + if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: + casting_dtype = old_param.dtype + + return old_param is not None and old_param.is_contiguous(), casting_dtype + + def _cast_and_contiguous(tensor, to_contiguous, dtype): + if dtype is not None: + tensor = tensor.to(dtype=dtype) + if to_contiguous: + tensor = tensor.contiguous() + return tensor + + if accelerator.is_main_process: + for (param_name, full_param), sharded_param in zip(full_sd.items(), meta_sharded_sd.values()): + device_mesh = sharded_param.device_mesh + full_param = full_param.detach().to(device_mesh.device_type) + if isinstance(full_param, DTensor): + # dist.broadcast() only supports torch.Tensor. + # After prepare_tp(), model parameters may become DTensor. + # To broadcast such a parameter, convert it to a local tensor first. + full_param = full_param.to_local() + dist.broadcast(full_param, src=0, group=dist.group.WORLD) + sharded_tensor = distribute_tensor(full_param, device_mesh, sharded_param.placements) + to_contiguous, casting_dtype = _infer_parameter_dtype( + model, + param_name, + full_param, + ) + sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype) + # When CPU offloading is enabled, FSDP2's lazy_init expects parameters on CPU + if cpu_offload: + sharded_tensor = sharded_tensor.to("cpu") + sharded_sd[param_name] = sharded_tensor + # We need this else to have a matching `broadcast` for all of the ranks, else we deadlock + else: + for param_name, sharded_param in meta_sharded_sd.items(): + device_mesh = sharded_param.device_mesh + full_tensor = torch.empty(sharded_param.size(), device=device_mesh.device_type, dtype=sharded_param.dtype) + dist.broadcast(full_tensor, src=0, group=dist.group.WORLD) + sharded_tensor = distribute_tensor(full_tensor, device_mesh, sharded_param.placements) + to_contiguous, casting_dtype = _infer_parameter_dtype( + model, + param_name, + full_tensor, + ) + sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype) + # When CPU offloading is enabled, FSDP2's lazy_init expects parameters on CPU + if cpu_offload: + sharded_tensor = sharded_tensor.to("cpu") + sharded_sd[param_name] = sharded_tensor + + # we set `assign=True` because our params are on meta device + model.load_state_dict(sharded_sd, assign=True) + return model + + +def fsdp2_switch_optimizer_parameters(optimizer: torch.optim.Optimizer, mapping: dict): + """ + Switches the parameters of the optimizer to new ones (sharded parameters in usual case). This function modifies the + optimizer in-place. + + Args: + optimizer (`torch.optim.Optimizer`): Optimizer instance which contains the original model parameters + mapping (`dict`): Mapping from the original parameter (specified by `data_ptr`) to the sharded parameter + + Raises: + KeyError: + If a parameter in the optimizer couldn't be switched to its sharded version. This should never happen and + indicates a bug. If we kept the original params instead of raising, the training wouldn't be numerically + correct and weights wouldn't get updated. + """ + from torch.distributed.tensor import DTensor + + accessor_mapping = {} + + accessor_mapping[DTensor] = "_local_tensor" + try: + for param_group in optimizer.param_groups: + param_group["params"] = [mapping[p.data_ptr] for p in param_group["params"]] + except KeyError: + # This shouldn't ever happen, but we want to fail here else training wouldn't be numerically correct + # This basically means that we're missing a mapping from the original parameter to the sharded parameter + raise KeyError( + "A parameter in the optimizer couldn't be switched to its sharded version. This breaks the training. Please raise an issue on GitHub." + ) + + +def fsdp2_apply_ac(accelerator, model: torch.nn.Module): + """ + Applies the activation checkpointing to the model. + + Args: + accelerator (`Accelerator`): The accelerator instance + model (`torch.nn.Module`): The model to apply the activation checkpointing to + + Returns: + `torch.nn.Module`: The model with the activation checkpointing applied + """ + + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper, + ) + + auto_wrap_policy_func = fsdp2_prepare_auto_wrap_policy(accelerator.state.fsdp_plugin, model) + + for layer_name, layer in get_module_children_bottom_up(model, return_fqns=True)[:-1]: + if len(layer_name.split(".")) > 1: + parent_name, child_name = layer_name.rsplit(".", 1) + else: + parent_name = None + child_name = layer_name + + parent_module = model.get_submodule(parent_name) if parent_name else model + if auto_wrap_policy_func(parent_module): + layer = checkpoint_wrapper(layer, preserve_rng_state=False) + parent_module.register_module(child_name, layer) + + return model + + +def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: + """Prepares the model for FSDP2 in-place. Also returns the model to avoid misuse of the original model. + + Args: + accelerator (`Accelerator`): The accelerator instance + model (`torch.nn.Module`): The model to prepare + + Returns: + `torch.nn.Module`: Prepared model + """ + from torch.distributed.fsdp import FSDPModule, MixedPrecisionPolicy, fully_shard + + is_type_fsdp = isinstance(model, FSDPModule) or ( + is_compiled_module(model) and isinstance(model._orig_mod, FSDPModule) + ) + if is_type_fsdp: + return model + + fsdp2_plugin = accelerator.state.fsdp_plugin + + fsdp2_plugin.set_auto_wrap_policy(model) + + original_sd = model.state_dict() + mesh = getattr(accelerator, "torch_device_mesh", None) + + fsdp2_kwargs = { + "reshard_after_forward": fsdp2_plugin.reshard_after_forward, + "offload_policy": fsdp2_plugin.cpu_offload, + # `fully_shard` does not accept `None` in case of `MixedPrecisionPolicy` + "mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(), + "mesh": mesh[tuple(accelerator.parallelism_config.fsdp_dim_names)] if mesh is not None else None, + } + + # `ignored_params` is only supported in torch >= 2.7.0 + if is_torch_version(">=", "2.7.0") and fsdp2_plugin.ignored_modules is not None: + fsdp2_kwargs["ignored_params"] = get_parameters_from_modules( + fsdp2_plugin.ignored_modules, model, accelerator.device + ) + + model_has_params4bit = False + for name, param in model.named_parameters(): + # this is a temporary fix whereby loading models with bnb params cannot be moved from + # GPU to a meta device due with FSDP2 because torch operations don't return the original class type + # bypassing the move to meta will still cause the VRAM spike, but at least it still will load + if param.__class__.__name__ == "Params4bit": + model_has_params4bit = True + break + + if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit: + # Context: `fully_shard` moves the model to GPU if it was on CPU, however it can also be on `meta` and then it stays there even after `fully_shard` + # For this reason, we need to move the model to `meta` device, as then sharding happens on `meta` device + # If we kept the model on CPU (`cpu_ram_efficient_loading` has model be on CPU on all ranks, though non-main ranks only have `torch.empty`), `fully_shard` would move it to GPU + # Afterwards, when we call `fsdp2_load_full_state_dict`, us creating the state_dict would result into briefly having two copies of model state_dict on the GPU -> VRAM spike + + # We need to keep the original non-persistent buffers, as those MAY not be in the state_dict, resulting in them staying on meta device + # Also, these buffers aren't getting sharded by default + # We get the FQNs of all non-persistent buffers, to re-register them after + non_persistent_buffer_fqns = get_non_persistent_buffers(model, recurse=True, fqns=True) + original_non_persistent_buffers = copy.deepcopy( + {k: v for k, v in model.named_buffers() if k in non_persistent_buffer_fqns} + ) + # We move the model to meta device, as then sharding happens on meta device + model = model.to(torch.device("meta")) + # We need to re-tie the weights, not exactly sure why, but if we don't do this, reference to `lm_head/embed_tokens` stay hanging -> more VRAM usage + # We assume `transformers` models have a `tie_weights` method if they support it + if hasattr(model, "tie_weights"): + model.tie_weights() + + auto_wrap_policy_func = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model) + if auto_wrap_policy_func is not None: + # We skip the model itself, as that one is always wrapped + for module in get_module_children_bottom_up(model)[:-1]: + if auto_wrap_policy_func(module) and not isinstance(module, FSDPModule): + fully_shard(module, **fsdp2_kwargs) + + if not isinstance(model, FSDPModule): + fully_shard(model, **fsdp2_kwargs) + + if fsdp2_plugin.cpu_ram_efficient_loading: + # If `cpu_ram_efficient_loading` is enabled, only rank 0 loads the weights + # Other ranks have an empty model on `meta` device, so we need to distribute the weights properly + # When CPU offloading is enabled, parameters need to stay on CPU after distribution + from torch.distributed.fsdp import CPUOffloadPolicy + + fsdp2_load_full_state_dict( + accelerator, model, original_sd, cpu_offload=isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy) + ) + + if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit: + # We re-register the buffers, as they may not be in the state_dict + for fqn, buffer_tensor in original_non_persistent_buffers.items(): + buffer_tensor = buffer_tensor.to(accelerator.device) + + if "." in fqn: + parent_fqn, local_buffer_name = fqn.rsplit(".", 1) + parent_module = model.get_submodule(parent_fqn) + else: + local_buffer_name = fqn + parent_module = model + + parent_module.register_buffer(local_buffer_name, buffer_tensor, persistent=False) + + # We need to tie the weights again, as call to `load_full_state_dict` breaks the tie + # Needs to be called both here and above + # removing this call makes the have slightly different loss + # removing the call above leads to extra memory usage as explained in the comment above + if hasattr(model, "tie_weights"): + model.tie_weights() + + # There is no `dtype` attribution for nn.Module + # Set it to None if it doesn't exist and do the upcast always + model_dtype = getattr(model, "dtype", None) + if accelerator.mixed_precision != "no" and (model_dtype is None or model_dtype != torch.float32): + # We upcast the trainable parameters according to `deepspeed`'s implementation + # More info about this can be found in `accelerator.py:prepare_model`s FSDP1 section + upcasted_params = [] + for name, param in model.named_parameters(): + if param.requires_grad and param.dtype != torch.float32: + upcasted_params.append(name) + param = param.to(torch.float32) + if accelerator.is_main_process and upcasted_params: + warnings.warn( + "FSDP upcast of low precision parameters to fp32 (since mixed_precision != 'no') may affect the precision of model checkpoints. " + f"This effects {len(upcasted_params)} parameters: {upcasted_params}..." + ) + return model + + +def fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model: torch.nn.Module) -> Callable[[torch.nn.Module], bool]: + """Prepares the auto wrap policy based on its type, done to mimic the behaviour of FSDP1 auto wrap policy. + + Args: + fsdp2_plugin (`FullyShardedDataParallelPlugin`): + Instance of `FullyShardedDataParallelPlugin` containing the configuration options + auto_wrap_policy_type (`str`): + Either `transformer` or `size` + model (`torch.nn.Module`): + The model to wrap + + Returns: + `Callable[[torch.nn.Module], bool]`: + The auto wrap policy function to be applied to the model + """ + from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy + + fn = fsdp2_plugin.auto_wrap_policy + + if isinstance(fn, functools.partial): + fn = fn.func + + if fn is transformer_auto_wrap_policy: + no_split_modules = getattr(model, "_no_split_modules", None) + if no_split_modules is None: + no_split_modules = [] + transformer_cls_names_to_wrap = list(no_split_modules) + if fsdp2_plugin.transformer_cls_names_to_wrap is not None: + transformer_cls_names_to_wrap = fsdp2_plugin.transformer_cls_names_to_wrap + transformer_cls_to_wrap = set() + + for layer_class in transformer_cls_names_to_wrap: + transformer_cls = get_module_class_from_name(model, layer_class) + if transformer_cls is None: + raise ValueError(f"Could not find the transformer layer class {layer_class} in the model.") + transformer_cls_to_wrap.add(transformer_cls) + + def policy(module: torch.nn.Module) -> bool: + if fsdp2_plugin.transformer_cls_names_to_wrap is None: + return False + return isinstance(module, tuple(transformer_cls_to_wrap)) + + elif fn is size_based_auto_wrap_policy: + + def policy(module: torch.nn.Module) -> bool: + module_num_params = sum(p.numel() for p in module.parameters()) + return module_num_params > fsdp2_plugin.min_num_params + else: + return None + + return policy + + +def get_fsdp2_grad_scaler(**kwargs): + """ + Returns a `GradScaler` for FSDP2, as the current implementation of `get_grad_scaler` doesn't accept other args. We + need this as current `get_grad_scaler` accepts only `distributed_type` as arg, which doesn't differentiate between + FSDP1 and FSDP2 + """ + from torch.amp.grad_scaler import GradScaler + + return GradScaler(**kwargs) + + +def fsdp2_canonicalize_names(named_params: dict) -> dict: + """Removes parameter name modifiers in order to map them back to their original names. + + See huggingface/accelerate#3554 for more context. + + Args: + named_params (`dict`): The named parameters dictionary to canonicalize. + + Returns: + `dict`: The canonicalized named parameters dictionary + """ + named_params = {k.replace("._checkpoint_wrapped_module", ""): v for k, v in named_params.items()} + named_params = { + k.replace("_orig_mod.", "") if k.startswith("_orig_mod.") else k: v for k, v in named_params.items() + } + named_params = {k.replace("._orig_mod", ""): v for k, v in named_params.items()} + return named_params + + +def get_parameters_from_modules( + modules: Union[Iterable[torch.nn.Module], str], model, device +) -> set[torch.nn.Parameter]: + """Converts modules to parameters where modules can be a string or list of torch.nn.Module + + Args: + modules (`Union[Iterable[torch.nn.Module], str]`): List of modules + + Returns: + `set[torch.nn.Parameter]`: List of parameters + """ + if modules is None: + return set() + parameters = [] + # code taken from accelerate while preparing kwargs for FSDP + if isinstance(modules, str): + reg = re.compile(modules) + mapped_modules = [] + for name, module in model.named_modules(): + if reg.fullmatch(name): + module.to(device) + mapped_modules.append(module) + modules = mapped_modules + for module in modules: + parameters.extend(list(module.parameters())) + return set(parameters) diff --git a/accelerate/utils/imports.py b/accelerate/utils/imports.py new file mode 100644 index 0000000000000000000000000000000000000000..0704088d2096902c952867add39892c32b1f661e --- /dev/null +++ b/accelerate/utils/imports.py @@ -0,0 +1,536 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import importlib.metadata +import os +import sys +import warnings +from functools import lru_cache, wraps + +import torch +from packaging import version +from packaging.version import parse + +from .environment import parse_flag_from_env, patch_environment, str_to_bool +from .versions import compare_versions, is_torch_version + + +# Try to run Torch native job in an environment with TorchXLA installed by setting this value to 0. +USE_TORCH_XLA = parse_flag_from_env("USE_TORCH_XLA", default=True) + +_torch_xla_available = False +if USE_TORCH_XLA: + try: + import torch_xla.core.xla_model as xm # noqa: F401 + import torch_xla.runtime + + _torch_xla_available = True + except ImportError: + pass + +# Keep it for is_tpu_available. It will be removed along with is_tpu_available. +_tpu_available = _torch_xla_available + +# Cache this result has it's a C FFI call which can be pretty time-consuming +_torch_distributed_available = torch.distributed.is_available() + + +def _is_package_available(pkg_name, metadata_name=None): + # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version + package_exists = importlib.util.find_spec(pkg_name) is not None + if package_exists: + try: + # Some libraries have different names in the metadata + _ = importlib.metadata.metadata(pkg_name if metadata_name is None else metadata_name) + return True + except importlib.metadata.PackageNotFoundError: + return False + + +def is_torch_distributed_available() -> bool: + return _torch_distributed_available + + +def is_xccl_available(): + if is_torch_version(">=", "2.7.0"): + return torch.distributed.distributed_c10d.is_xccl_available() + return False + + +def is_import_timer_available(): + return _is_package_available("import_timer") + + +def is_pynvml_available(): + return _is_package_available("pynvml") or _is_package_available("pynvml", "nvidia-ml-py") + + +def is_pytest_available(): + return _is_package_available("pytest") + + +def is_msamp_available(): + return _is_package_available("msamp", "ms-amp") + + +def is_schedulefree_available(): + return _is_package_available("schedulefree") + + +def is_transformer_engine_available(): + if is_hpu_available(): + return _is_package_available("intel_transformer_engine", "intel-transformer-engine") + else: + return _is_package_available("transformer_engine", "transformer-engine") + + +def is_transformer_engine_mxfp8_available(): + if _is_package_available("transformer_engine", "transformer-engine"): + from transformer_engine.pytorch.fp8 import check_mxfp8_support + + return check_mxfp8_support()[0] + return False + + +def is_lomo_available(): + return _is_package_available("lomo_optim") + + +def is_cuda_available(): + """ + Checks if `cuda` is available via an `nvml-based` check which won't trigger the drivers and leave cuda + uninitialized. + """ + with patch_environment(PYTORCH_NVML_BASED_CUDA_CHECK="1"): + available = torch.cuda.is_available() + + return available + + +@lru_cache +def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False): + """ + Check if `torch_xla` is available. To train a native pytorch job in an environment with torch xla installed, set + the USE_TORCH_XLA to false. + """ + assert not (check_is_tpu and check_is_gpu), "The check_is_tpu and check_is_gpu cannot both be true." + + if not _torch_xla_available: + return False + elif check_is_gpu: + return torch_xla.runtime.device_type() in ["GPU", "CUDA"] + elif check_is_tpu: + return torch_xla.runtime.device_type() == "TPU" + + return True + + +def is_torchao_available(): + package_exists = _is_package_available("torchao") + if package_exists: + torchao_version = version.parse(importlib.metadata.version("torchao")) + return compare_versions(torchao_version, ">=", "0.6.1") + return False + + +def is_deepspeed_available(): + return _is_package_available("deepspeed") + + +def is_pippy_available(): + return is_torch_version(">=", "2.4.0") + + +def is_bf16_available(ignore_tpu=False): + "Checks if bf16 is supported, optionally ignoring the TPU" + if is_torch_xla_available(check_is_tpu=True): + return not ignore_tpu + if is_cuda_available(): + return torch.cuda.is_bf16_supported() + if is_mlu_available(): + return torch.mlu.is_bf16_supported() + if is_xpu_available(): + return torch.xpu.is_bf16_supported() + if is_mps_available(): + return torch.backends.mps.is_macos_or_newer(14, 0) + return True + + +def is_fp16_available(): + "Checks if fp16 is supported" + if is_habana_gaudi1(): + return False + + return True + + +def is_fp8_available(): + "Checks if fp8 is supported" + return is_msamp_available() or is_transformer_engine_available() or is_torchao_available() + + +def is_4bit_bnb_available(): + package_exists = _is_package_available("bitsandbytes") + if package_exists: + bnb_version = version.parse(importlib.metadata.version("bitsandbytes")) + return compare_versions(bnb_version, ">=", "0.39.0") + return False + + +def is_8bit_bnb_available(): + package_exists = _is_package_available("bitsandbytes") + if package_exists: + bnb_version = version.parse(importlib.metadata.version("bitsandbytes")) + return compare_versions(bnb_version, ">=", "0.37.2") + return False + + +def is_bnb_available(min_version=None): + package_exists = _is_package_available("bitsandbytes") + if package_exists and min_version is not None: + bnb_version = version.parse(importlib.metadata.version("bitsandbytes")) + return compare_versions(bnb_version, ">=", min_version) + else: + return package_exists + + +def is_bitsandbytes_multi_backend_available(): + if not is_bnb_available(): + return False + import bitsandbytes as bnb + + return "multi_backend" in getattr(bnb, "features", set()) + + +def is_torchvision_available(): + return _is_package_available("torchvision") + + +def is_megatron_lm_available(): + if str_to_bool(os.environ.get("ACCELERATE_USE_MEGATRON_LM", "False")) == 1: + if importlib.util.find_spec("megatron") is not None: + try: + megatron_version = parse(importlib.metadata.version("megatron-core")) + if compare_versions(megatron_version, ">=", "0.8.0"): + return importlib.util.find_spec(".training", "megatron") + except Exception as e: + warnings.warn(f"Parse Megatron version failed. Exception:{e}") + return False + + +def is_transformers_available(): + return _is_package_available("transformers") + + +def is_datasets_available(): + return _is_package_available("datasets") + + +def is_peft_available(): + return _is_package_available("peft") + + +def is_timm_available(): + return _is_package_available("timm") + + +def is_triton_available(): + if is_xpu_available(): + return _is_package_available("triton", "pytorch-triton-xpu") + return _is_package_available("triton") + + +def is_aim_available(): + package_exists = _is_package_available("aim") + if package_exists: + aim_version = version.parse(importlib.metadata.version("aim")) + return compare_versions(aim_version, "<", "4.0.0") + return False + + +def is_tensorboard_available(): + return _is_package_available("tensorboard") or _is_package_available("tensorboardX") + + +def is_wandb_available(): + return _is_package_available("wandb") + + +def is_comet_ml_available(): + return _is_package_available("comet_ml") + + +def is_swanlab_available(): + return _is_package_available("swanlab") + + +def is_trackio_available(): + return sys.version_info >= (3, 10) and _is_package_available("trackio") + + +def is_boto3_available(): + return _is_package_available("boto3") + + +def is_rich_available(): + if _is_package_available("rich"): + return parse_flag_from_env("ACCELERATE_ENABLE_RICH", False) + return False + + +def is_sagemaker_available(): + return _is_package_available("sagemaker") + + +def is_tqdm_available(): + return _is_package_available("tqdm") + + +def is_clearml_available(): + return _is_package_available("clearml") + + +def is_pandas_available(): + return _is_package_available("pandas") + + +def is_matplotlib_available(): + return _is_package_available("matplotlib") + + +def is_mlflow_available(): + if _is_package_available("mlflow"): + return True + + if importlib.util.find_spec("mlflow") is not None: + try: + _ = importlib.metadata.metadata("mlflow-skinny") + return True + except importlib.metadata.PackageNotFoundError: + return False + return False + + +def is_mps_available(min_version="1.12"): + "Checks if MPS device is available. The minimum version required is 1.12." + # With torch 1.12, you can use torch.backends.mps + # With torch 2.0.0, you can use torch.mps + return is_torch_version(">=", min_version) and torch.backends.mps.is_available() and torch.backends.mps.is_built() + + +@lru_cache +def is_mlu_available(check_device=False): + """ + Checks if `mlu` is available via an `cndev-based` check which won't trigger the drivers and leave mlu + uninitialized. + """ + if importlib.util.find_spec("torch_mlu") is None: + return False + + import torch_mlu # noqa: F401 + + with patch_environment(PYTORCH_CNDEV_BASED_MLU_CHECK="1"): + available = torch.mlu.is_available() + + return available + + +@lru_cache +def is_musa_available(check_device=False): + "Checks if `torch_musa` is installed and potentially if a MUSA is in the environment" + if importlib.util.find_spec("torch_musa") is None: + return False + + import torch_musa # noqa: F401 + + if check_device: + try: + # Will raise a RuntimeError if no MUSA is found + _ = torch.musa.device_count() + return torch.musa.is_available() + except RuntimeError: + return False + return hasattr(torch, "musa") and torch.musa.is_available() + + +@lru_cache +def is_npu_available(check_device=False): + "Checks if `torch_npu` is installed and potentially if a NPU is in the environment" + if importlib.util.find_spec("torch_npu") is None: + return False + + # NOTE: importing torch_npu may raise error in some envs + # e.g. inside cpu-only container with torch_npu installed + try: + import torch_npu # noqa: F401 + except Exception: + return False + + if check_device: + try: + # Will raise a RuntimeError if no NPU is found + _ = torch.npu.device_count() + return torch.npu.is_available() + except RuntimeError: + return False + return hasattr(torch, "npu") and torch.npu.is_available() + + +@lru_cache +def is_sdaa_available(check_device=False): + "Checks if `torch_sdaa` is installed and potentially if a SDAA is in the environment" + if importlib.util.find_spec("torch_sdaa") is None: + return False + + import torch_sdaa # noqa: F401 + + if check_device: + try: + # Will raise a RuntimeError if no NPU is found + _ = torch.sdaa.device_count() + return torch.sdaa.is_available() + except RuntimeError: + return False + return hasattr(torch, "sdaa") and torch.sdaa.is_available() + + +@lru_cache +def is_hpu_available(init_hccl=False): + "Checks if `torch.hpu` is installed and potentially if a HPU is in the environment" + if ( + importlib.util.find_spec("habana_frameworks") is None + or importlib.util.find_spec("habana_frameworks.torch") is None + ): + return False + + import habana_frameworks.torch # noqa: F401 + + if init_hccl: + import habana_frameworks.torch.distributed.hccl as hccl # noqa: F401 + + return hasattr(torch, "hpu") and torch.hpu.is_available() + + +def is_habana_gaudi1(): + if is_hpu_available(): + import habana_frameworks.torch.utils.experimental as htexp # noqa: F401 + + if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi: + return True + + return False + + +@lru_cache +def is_xpu_available(check_device=False): + """ + Checks if XPU acceleration is available via stock PyTorch (>=2.7) and + potentially if a XPU is in the environment + """ + + if is_torch_version("<=", "2.6"): + return False + + if check_device: + try: + # Will raise a RuntimeError if no XPU is found + _ = torch.xpu.device_count() + return torch.xpu.is_available() + except RuntimeError: + return False + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +@lru_cache +def is_neuron_available(check_device=False): + if importlib.util.find_spec("torch_neuronx") is None: + return False + + if check_device: + try: + import torch_neuronx # noqa: F401 + + # Will raise a RuntimeError if no Neuron is found + _ = torch.neuron.device_count() + return torch.neuron.is_available() + except RuntimeError: + return False + + return hasattr(torch, "neuron") and torch.neuron.is_available() + + +def is_dvclive_available(): + return _is_package_available("dvclive") + + +def is_torchdata_available(): + return _is_package_available("torchdata") + + +# TODO: Remove this function once stateful_dataloader is a stable feature in torchdata. +def is_torchdata_stateful_dataloader_available(): + package_exists = _is_package_available("torchdata") + if package_exists: + torchdata_version = version.parse(importlib.metadata.version("torchdata")) + return compare_versions(torchdata_version, ">=", "0.8.0") + return False + + +def torchao_required(func): + """ + A decorator that ensures the decorated function is only called when torchao is available. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + if not is_torchao_available(): + raise ImportError( + "`torchao` is not available, please install it before calling this function via `pip install torchao`." + ) + return func(*args, **kwargs) + + return wrapper + + +# TODO: Rework this into `utils.deepspeed` and migrate the "core" chunks into `accelerate.deepspeed` +def deepspeed_required(func): + """ + A decorator that ensures the decorated function is only called when deepspeed is enabled. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + from accelerate.state import AcceleratorState + from accelerate.utils.dataclasses import DistributedType + + if AcceleratorState._shared_state != {} and AcceleratorState().distributed_type != DistributedType.DEEPSPEED: + raise ValueError( + "DeepSpeed is not enabled, please make sure that an `Accelerator` is configured for `deepspeed` " + "before calling this function." + ) + return func(*args, **kwargs) + + return wrapper + + +def is_weights_only_available(): + # Weights only with allowlist was added in 2.4.0 + # ref: https://github.com/pytorch/pytorch/pull/124331 + return is_torch_version(">=", "2.4.0") + + +def is_numpy_available(min_version="1.25.0"): + numpy_version = parse(importlib.metadata.version("numpy")) + return compare_versions(numpy_version, ">=", min_version) diff --git a/accelerate/utils/launch.py b/accelerate/utils/launch.py new file mode 100644 index 0000000000000000000000000000000000000000..a1801c8c4335a445026b64bc00ff3fc4abe0414f --- /dev/null +++ b/accelerate/utils/launch.py @@ -0,0 +1,827 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import subprocess +import sys +import warnings +from ast import literal_eval +from shutil import which +from typing import Any + +import torch + +from ..commands.config.config_args import SageMakerConfig +from ..utils import ( + DynamoBackend, + PrecisionType, + is_fp8_available, + is_hpu_available, + is_mlu_available, + is_musa_available, + is_neuron_available, + is_npu_available, + is_sdaa_available, + is_torch_xla_available, + is_xpu_available, +) +from ..utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS +from ..utils.other import get_free_port, is_port_in_use, merge_dicts +from ..utils.versions import compare_versions +from . import parse_flag_from_env +from .dataclasses import DistributedType, SageMakerDistributedType + + +def _filter_args(args, parser, default_args=[]): + """ + Filters out all `accelerate` specific args + """ + new_args, _ = parser.parse_known_args(default_args) + for key, value in vars(args).items(): + if key in vars(new_args).keys(): + setattr(new_args, key, value) + return new_args + + +def _get_mpirun_args(): + """ + Determines the executable and argument names for mpirun, based on the type of install. The supported MPI programs + are: OpenMPI, Intel MPI, or MVAPICH. + + Returns: Program name and arg names for hostfile, num processes, and processes per node + """ + # Find the MPI program name + mpi_apps = [x for x in ["mpirun", "mpiexec"] if which(x)] + + if len(mpi_apps) == 0: + raise OSError("mpirun or mpiexec were not found. Ensure that Intel MPI, Open MPI, or MVAPICH are installed.") + + # Call the app with the --version flag to determine which MPI app is installed + mpi_app = mpi_apps[0] + mpirun_version = subprocess.check_output([mpi_app, "--version"]) + + if b"Open MPI" in mpirun_version: + return mpi_app, "--hostfile", "-n", "--npernode", "--bind-to" + else: + # Intel MPI and MVAPICH both use the same arg names + return mpi_app, "-f", "-n", "-ppn", "" + + +def setup_fp8_env(args: argparse.Namespace, current_env: dict[str, str]): + """ + Setup the FP8 environment variables. + """ + prefix = "ACCELERATE_" + for arg in vars(args): + if arg.startswith("fp8_"): + value = getattr(args, arg) + if value is not None: + if arg == "fp8_override_linear_precision": + current_env[prefix + "FP8_OVERRIDE_FPROP"] = str(value[0]) + current_env[prefix + "FP8_OVERRIDE_DGRAD"] = str(value[1]) + current_env[prefix + "FP8_OVERRIDE_WGRAD"] = str(value[2]) + else: + current_env[f"{prefix}{arg.upper()}"] = str(getattr(args, arg)) + return current_env + + +def prepare_simple_launcher_cmd_env(args: argparse.Namespace) -> tuple[list[str], dict[str, str]]: + """ + Prepares and returns the command list and an environment with the correct simple launcher environment variables. + """ + cmd = [] + if args.no_python and args.module: + raise ValueError("--module and --no_python cannot be used together") + + num_processes = getattr(args, "num_processes", None) + num_machines = args.num_machines + if args.mpirun_hostfile is not None: + mpi_app_name, hostfile_arg, num_proc_arg, proc_per_node_arg, bind_to_arg = _get_mpirun_args() + bind_to = getattr(args, "bind-to", "socket") + nproc_per_node = str(num_processes // num_machines) if num_processes and num_machines else "1" + cmd += [ + mpi_app_name, + hostfile_arg, + args.mpirun_hostfile, + proc_per_node_arg, + nproc_per_node, + ] + if num_processes: + cmd += [num_proc_arg, str(num_processes)] + if bind_to_arg: + cmd += [bind_to_arg, bind_to] + if not args.no_python: + cmd.append(sys.executable) + if args.module: + cmd.append("-m") + cmd.append(args.training_script) + cmd.extend(args.training_script_args) + + current_env = os.environ.copy() + current_env["ACCELERATE_USE_CPU"] = str(args.cpu or args.use_cpu) + if args.debug: + current_env["ACCELERATE_DEBUG_MODE"] = "true" + if args.gpu_ids != "all" and args.gpu_ids is not None: + if is_xpu_available(): + current_env["ZE_AFFINITY_MASK"] = args.gpu_ids + elif is_mlu_available(): + current_env["MLU_VISIBLE_DEVICES"] = args.gpu_ids + elif is_sdaa_available(): + current_env["SDAA_VISIBLE_DEVICES"] = args.gpu_ids + elif is_musa_available(): + current_env["MUSA_VISIBLE_DEVICES"] = args.gpu_ids + elif is_npu_available(): + current_env["ASCEND_RT_VISIBLE_DEVICES"] = args.gpu_ids + elif is_hpu_available(): + current_env["HABANA_VISIBLE_MODULES"] = args.gpu_ids + elif is_neuron_available(): + current_env["NEURON_RT_VISIBLE_CORES"] = args.gpu_ids + else: + current_env["CUDA_VISIBLE_DEVICES"] = args.gpu_ids + if num_machines > 1: + assert args.main_process_ip is not None, ( + "When using multiple machines, you need to specify the main process IP." + ) + assert args.main_process_port is not None, ( + "When using multiple machines, you need to specify the main process port." + ) + + if (num_processes is not None and num_processes > 1) or num_machines > 1: + current_env["MASTER_ADDR"] = args.main_process_ip if args.main_process_ip is not None else "127.0.0.1" + current_env["MASTER_PORT"] = str(args.main_process_port) if args.main_process_port is not None else "29500" + if parse_flag_from_env(current_env["ACCELERATE_USE_CPU"], False): + current_env["KMP_AFFINITY"] = "granularity=fine,compact,1,0" + current_env["KMP_BLOCKTIME"] = str(1) + + try: + mixed_precision = PrecisionType(args.mixed_precision.lower()) + except ValueError: + raise ValueError( + f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}." + ) + + current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision) + if args.mixed_precision.lower() == "fp8": + if not is_fp8_available(): + raise RuntimeError( + "FP8 is not available on this machine. Please ensure that either Transformer Engine, MSAMP or torchao is installed." + ) + current_env = setup_fp8_env(args, current_env) + + try: + dynamo_backend = DynamoBackend(args.dynamo_backend.upper()) + except ValueError: + raise ValueError( + f"Unknown dynamo backend: {args.dynamo_backend.upper()}. Choose between {DynamoBackend.list()}." + ) + current_env["ACCELERATE_DYNAMO_BACKEND"] = dynamo_backend.value + current_env["ACCELERATE_DYNAMO_MODE"] = args.dynamo_mode + current_env["ACCELERATE_DYNAMO_USE_FULLGRAPH"] = str(args.dynamo_use_fullgraph) + current_env["ACCELERATE_DYNAMO_USE_DYNAMIC"] = str(args.dynamo_use_dynamic) + current_env["ACCELERATE_DYNAMO_USE_REGIONAL_COMPILATION"] = str(args.dynamo_use_regional_compilation) + + current_env["OMP_NUM_THREADS"] = str(args.num_cpu_threads_per_process) + if args.enable_cpu_affinity: + current_env["ACCELERATE_CPU_AFFINITY"] = "1" + return cmd, current_env + + +def prepare_multi_gpu_env(args: argparse.Namespace) -> dict[str, str]: + """ + Prepares and returns an environment with the correct multi-GPU environment variables. + """ + # get free port and update configurations + if args.main_process_port == 0: + args.main_process_port = get_free_port() + + elif args.main_process_port is None: + args.main_process_port = 29500 + + num_processes = args.num_processes + num_machines = args.num_machines + main_process_ip = args.main_process_ip + main_process_port = args.main_process_port + if num_machines > 1: + args.nproc_per_node = str(num_processes // num_machines) + args.nnodes = str(num_machines) + args.node_rank = int(args.machine_rank) + if getattr(args, "same_network", False): + args.master_addr = str(main_process_ip) + args.master_port = str(main_process_port) + else: + args.rdzv_endpoint = f"{main_process_ip}:{main_process_port}" + else: + args.nproc_per_node = str(num_processes) + if main_process_port is not None: + args.master_port = str(main_process_port) + + # only need to check port availability in main process, in case we have to start multiple launchers on the same machine + # for some reasons like splitting log files. + need_port_check = num_machines <= 1 or int(args.machine_rank) == 0 + if need_port_check and is_port_in_use(main_process_port): + if num_machines <= 1: + args.standalone = True + warnings.warn( + f"Port `{main_process_port}` is already in use. " + "Accelerate will attempt to launch in a standalone-like mode by finding an open port automatically for this session. " + "If this current attempt fails, or for more control in future runs, please specify a different port " + "(e.g., `--main_process_port `) or use `--main_process_port 0` for automatic selection " + "in your launch command or Accelerate config file." + ) + else: + raise ConnectionError( + f"Tried to launch distributed communication on port `{main_process_port}`, but another process is utilizing it. " + "Please specify a different port (such as using the `--main_process_port` flag or specifying a different `main_process_port` in your config file)" + " and rerun your script. To automatically use the next open port (on a single node), you can set this to `0`." + ) + + if args.module and args.no_python: + raise ValueError("--module and --no_python cannot be used together") + elif args.module: + args.module = True + elif args.no_python: + args.no_python = True + + current_env = os.environ.copy() + if args.debug: + current_env["ACCELERATE_DEBUG_MODE"] = "true" + gpu_ids = getattr(args, "gpu_ids", "all") + if gpu_ids != "all" and args.gpu_ids is not None: + if is_xpu_available(): + current_env["ZE_AFFINITY_MASK"] = gpu_ids + elif is_mlu_available(): + current_env["MLU_VISIBLE_DEVICES"] = gpu_ids + elif is_sdaa_available(): + current_env["SDAA_VISIBLE_DEVICES"] = gpu_ids + elif is_musa_available(): + current_env["MUSA_VISIBLE_DEVICES"] = gpu_ids + elif is_npu_available(): + current_env["ASCEND_RT_VISIBLE_DEVICES"] = gpu_ids + elif is_hpu_available(): + current_env["HABANA_VISIBLE_MODULES"] = gpu_ids + elif is_neuron_available(): + current_env["NEURON_RT_VISIBLE_CORES"] = gpu_ids + else: + current_env["CUDA_VISIBLE_DEVICES"] = gpu_ids + mixed_precision = args.mixed_precision.lower() + try: + mixed_precision = PrecisionType(mixed_precision) + except ValueError: + raise ValueError(f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}.") + + current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision) + if args.mixed_precision.lower() == "fp8": + if not is_fp8_available(): + raise RuntimeError( + "FP8 is not available on this machine. Please ensure that either Transformer Engine, MSAMP or torchao is installed." + ) + current_env = setup_fp8_env(args, current_env) + + try: + dynamo_backend = DynamoBackend(args.dynamo_backend.upper()) + except ValueError: + raise ValueError( + f"Unknown dynamo backend: {args.dynamo_backend.upper()}. Choose between {DynamoBackend.list()}." + ) + current_env["ACCELERATE_DYNAMO_BACKEND"] = dynamo_backend.value + current_env["ACCELERATE_DYNAMO_MODE"] = args.dynamo_mode + current_env["ACCELERATE_DYNAMO_USE_FULLGRAPH"] = str(args.dynamo_use_fullgraph) + current_env["ACCELERATE_DYNAMO_USE_DYNAMIC"] = str(args.dynamo_use_dynamic) + current_env["ACCELERATE_DYNAMO_USE_REGIONAL_COMPILATION"] = str(args.dynamo_use_regional_compilation) + + if args.use_fsdp: + current_env["ACCELERATE_USE_FSDP"] = "true" + if args.fsdp_cpu_ram_efficient_loading and not args.fsdp_sync_module_states: + raise ValueError("When using `--fsdp_cpu_ram_efficient_loading` set `--fsdp_sync_module_states` to `True`") + + current_env["FSDP_VERSION"] = str(args.fsdp_version) if hasattr(args, "fsdp_version") else "1" + + # For backwards compatibility, we support this in launched scripts, + # however, we do not ask users for this in `accelerate config` CLI + current_env["FSDP_SHARDING_STRATEGY"] = str(args.fsdp_sharding_strategy) + + current_env["FSDP_RESHARD_AFTER_FORWARD"] = str(args.fsdp_reshard_after_forward).lower() + current_env["FSDP_OFFLOAD_PARAMS"] = str(args.fsdp_offload_params).lower() + current_env["FSDP_MIN_NUM_PARAMS"] = str(args.fsdp_min_num_params) + if args.fsdp_auto_wrap_policy is not None: + current_env["FSDP_AUTO_WRAP_POLICY"] = str(args.fsdp_auto_wrap_policy) + if args.fsdp_transformer_layer_cls_to_wrap is not None: + current_env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = str(args.fsdp_transformer_layer_cls_to_wrap) + if args.fsdp_backward_prefetch is not None: + current_env["FSDP_BACKWARD_PREFETCH"] = str(args.fsdp_backward_prefetch) + if args.fsdp_state_dict_type is not None: + current_env["FSDP_STATE_DICT_TYPE"] = str(args.fsdp_state_dict_type) + current_env["FSDP_FORWARD_PREFETCH"] = str(args.fsdp_forward_prefetch).lower() + current_env["FSDP_USE_ORIG_PARAMS"] = str(args.fsdp_use_orig_params).lower() + current_env["FSDP_CPU_RAM_EFFICIENT_LOADING"] = str(args.fsdp_cpu_ram_efficient_loading).lower() + current_env["FSDP_SYNC_MODULE_STATES"] = str(args.fsdp_sync_module_states).lower() + current_env["FSDP_ACTIVATION_CHECKPOINTING"] = str(args.fsdp_activation_checkpointing).lower() + if getattr(args, "fsdp_ignored_modules", None) is not None: + current_env["FSDP_IGNORED_MODULES"] = str(args.fsdp_ignored_modules) + + if args.use_megatron_lm: + prefix = "MEGATRON_LM_" + current_env["ACCELERATE_USE_MEGATRON_LM"] = "true" + current_env[prefix + "TP_DEGREE"] = str(args.megatron_lm_tp_degree) + current_env[prefix + "USE_CUSTOM_FSDP"] = str(args.megatron_lm_use_custom_fsdp) + if args.megatron_lm_no_load_optim is not None: + current_env[prefix + "NO_LOAD_OPTIM"] = str(args.megatron_lm_no_load_optim) + if args.megatron_lm_eod_mask_loss is not None: + current_env[prefix + "EOD_MASK_LOSS"] = str(args.megatron_lm_eod_mask_loss) + if args.megatron_lm_no_save_optim is not None: + current_env[prefix + "NO_SAVE_OPTIM"] = str(args.megatron_lm_no_save_optim) + if args.megatron_lm_optimizer_cpu_offload is not None: + current_env[prefix + "OPTIMIZER_CPU_OFFLOAD"] = str(args.megatron_lm_optimizer_cpu_offload) + if args.megatron_lm_use_precision_aware_optimizer is not None: + current_env[prefix + "USE_PRECISION_AWARE_OPTIMIZER"] = str(args.megatron_lm_use_precision_aware_optimizer) + if args.megatron_lm_overlap_cpu_optimizer_d2h_h2d is not None: + current_env[prefix + "OVERLAP_CPU_OPTIMIZER_D2H_H2D"] = str(args.megatron_lm_overlap_cpu_optimizer_d2h_h2d) + if args.megatron_lm_decoder_last_pipeline_num_layers is not None: + current_env[prefix + "DECODER_LAST_PIPELINE_NUM_LAYERS"] = str( + args.megatron_lm_decoder_last_pipeline_num_layers + ) + current_env[prefix + "PP_DEGREE"] = str(args.megatron_lm_pp_degree) + current_env[prefix + "GRADIENT_CLIPPING"] = str(args.megatron_lm_gradient_clipping) + if args.megatron_lm_num_micro_batches is not None: + current_env[prefix + "NUM_MICRO_BATCHES"] = str(args.megatron_lm_num_micro_batches) + if args.megatron_lm_sequence_parallelism is not None: + current_env[prefix + "SEQUENCE_PARALLELISM"] = str(args.megatron_lm_sequence_parallelism) + if args.megatron_lm_recompute_activations is not None: + current_env[prefix + "RECOMPUTE_ACTIVATIONS"] = str(args.megatron_lm_recompute_activations) + if args.megatron_lm_use_distributed_optimizer is not None: + current_env[prefix + "USE_DISTRIBUTED_OPTIMIZER"] = str(args.megatron_lm_use_distributed_optimizer) + if args.megatron_lm_recompute_granularity is not None: + current_env[prefix + "RECOMPUTE_GRANULARITY"] = str(args.megatron_lm_recompute_granularity) + if args.megatron_lm_recompute_method is not None: + current_env[prefix + "RECOMPUTE_METHOD"] = str(args.megatron_lm_recompute_method) + if args.megatron_lm_recompute_num_layers is not None: + current_env[prefix + "RECOMPUTE_NUM_LAYERS"] = str(args.megatron_lm_recompute_num_layers) + if args.megatron_lm_attention_backend is not None: + current_env[prefix + "ATTENTION_BACKEND"] = str(args.megatron_lm_attention_backend) + if args.megatron_lm_expert_model_parallel_size is not None: + current_env[prefix + "EXPERT_MODEL_PARALLEL_SIZE"] = str(args.megatron_lm_expert_model_parallel_size) + if args.megatron_lm_context_parallel_size is not None: + current_env[prefix + "CONTEXT_PARALLEL_SIZE"] = str(args.megatron_lm_context_parallel_size) + if args.megatron_lm_attention_dropout is not None: + current_env[prefix + "ATTENTION_DROPOUT"] = str(args.megatron_lm_attention_dropout) + if args.megatron_lm_hidden_dropout is not None: + current_env[prefix + "HIDDEN_DROPOUT"] = str(args.megatron_lm_hidden_dropout) + if args.megatron_lm_attention_softmax_in_fp32 is not None: + current_env[prefix + "ATTENTION_SOFTMAX_IN_FP32"] = str(args.megatron_lm_attention_softmax_in_fp32) + if args.megatron_lm_expert_tensor_parallel_size is not None: + current_env[prefix + "EXPERT_TENSOR_PARALLEL_SIZE"] = str(args.megatron_lm_expert_tensor_parallel_size) + if args.megatron_lm_calculate_per_token_loss is not None: + current_env[prefix + "CALCULATE_PER_TOKEN_LOSS"] = str(args.megatron_lm_calculate_per_token_loss) + if args.megatron_lm_use_rotary_position_embeddings is not None: + current_env[prefix + "USE_ROTARY_POSITION_EMBEDDINGS"] = str( + args.megatron_lm_use_rotary_position_embeddings + ) + + current_env["OMP_NUM_THREADS"] = str(args.num_cpu_threads_per_process) + if args.enable_cpu_affinity: + current_env["ACCELERATE_CPU_AFFINITY"] = "1" + + if args.use_parallelism_config: + current_env = prepare_extend_env_parallelism_config(args, current_env) + + return current_env + + +def prepare_extend_env_parallelism_config( + args: argparse.Namespace, current_env: dict +) -> tuple[list[str], dict[str, str]]: + """ + Extends `current_env` with context parallelism env vars if any have been set + """ + + prefix = "PARALLELISM_CONFIG_" + + current_env["ACCELERATE_USE_PARALLELISM_CONFIG"] = "true" + current_env[prefix + "DP_REPLICATE_SIZE"] = str(args.parallelism_config_dp_replicate_size) + current_env[prefix + "DP_SHARD_SIZE"] = str(args.parallelism_config_dp_shard_size) + current_env[prefix + "TP_SIZE"] = str(args.parallelism_config_tp_size) + current_env[prefix + "CP_SIZE"] = str(args.parallelism_config_cp_size) + current_env[prefix + "CP_BACKEND"] = str(args.parallelism_config_cp_backend) + current_env[prefix + "SP_SIZE"] = str(args.parallelism_config_sp_size) + current_env[prefix + "SP_BACKEND"] = str(args.parallelism_config_sp_backend) + if args.parallelism_config_cp_size > 1: + current_env[prefix + "CP_COMM_STRATEGY"] = str(args.parallelism_config_cp_comm_strategy) + if args.parallelism_config_sp_size > 1: + current_env[prefix + "SP_SEQ_LENGTH"] = str(args.parallelism_config_sp_seq_length) + current_env[prefix + "SP_SEQ_LENGTH_IS_VARIABLE"] = str(args.parallelism_config_sp_seq_length_is_variable) + current_env[prefix + "SP_ATTN_IMPLEMENTATION"] = str(args.parallelism_config_sp_attn_implementation) + + return current_env + + +def prepare_deepspeed_cmd_env(args: argparse.Namespace) -> tuple[list[str], dict[str, str]]: + """ + Prepares and returns the command list and an environment with the correct DeepSpeed environment variables. + """ + # get free port and update configurations + if args.main_process_port == 0: + args.main_process_port = get_free_port() + + elif args.main_process_port is None: + args.main_process_port = 29500 + + num_processes = args.num_processes + num_machines = args.num_machines + main_process_ip = args.main_process_ip + main_process_port = args.main_process_port + cmd = None + + # make sure launcher is not None + if args.deepspeed_multinode_launcher is None: + # set to default pdsh + args.deepspeed_multinode_launcher = DEEPSPEED_MULTINODE_LAUNCHERS[0] + + if num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]: + cmd = ["deepspeed"] + cmd.extend(["--hostfile", str(args.deepspeed_hostfile)]) + if args.deepspeed_multinode_launcher == "nossh": + if compare_versions("deepspeed", "<", "0.14.5"): + raise ValueError("nossh launcher requires DeepSpeed >= 0.14.5") + cmd.extend(["--node_rank", str(args.machine_rank), "--no_ssh"]) + else: + cmd.extend(["--no_local_rank", "--launcher", str(args.deepspeed_multinode_launcher)]) + if args.deepspeed_exclusion_filter is not None: + cmd.extend( + [ + "--exclude", + str(args.deepspeed_exclusion_filter), + ] + ) + elif args.deepspeed_inclusion_filter is not None: + cmd.extend( + [ + "--include", + str(args.deepspeed_inclusion_filter), + ] + ) + else: + cmd.extend(["--num_gpus", str(args.num_processes // args.num_machines)]) + if main_process_ip: + cmd.extend(["--master_addr", str(main_process_ip)]) + cmd.extend(["--master_port", str(main_process_port)]) + if args.module and args.no_python: + raise ValueError("--module and --no_python cannot be used together") + elif args.module: + cmd.append("--module") + elif args.no_python: + cmd.append("--no_python") + cmd.append(args.training_script) + cmd.extend(args.training_script_args) + elif num_machines > 1 and args.deepspeed_multinode_launcher == DEEPSPEED_MULTINODE_LAUNCHERS[1]: + args.nproc_per_node = str(num_processes // num_machines) + args.nnodes = str(num_machines) + args.node_rank = int(args.machine_rank) + if getattr(args, "same_network", False): + args.master_addr = str(main_process_ip) + args.master_port = str(main_process_port) + else: + args.rdzv_endpoint = f"{main_process_ip}:{main_process_port}" + else: + args.nproc_per_node = str(num_processes) + if main_process_port is not None: + args.master_port = str(main_process_port) + + # only need to check port availability in main process, in case we have to start multiple launchers on the same machine + # for some reasons like splitting log files. + need_port_check = num_machines <= 1 or int(args.machine_rank) == 0 + if need_port_check and is_port_in_use(main_process_port): + if num_machines <= 1: + args.standalone = True + warnings.warn( + f"Port `{main_process_port}` is already in use. " + "Accelerate will attempt to launch in a standalone-like mode by finding an open port automatically for this session. " + "If this current attempt fails, or for more control in future runs, please specify a different port " + "(e.g., `--main_process_port `) or use `--main_process_port 0` for automatic selection " + "in your launch command or Accelerate config file." + ) + else: + raise ConnectionError( + f"Tried to launch distributed communication on port `{main_process_port}`, but another process is utilizing it. " + "Please specify a different port (such as using the `--main_process_port` flag or specifying a different `main_process_port` in your config file)" + " and rerun your script. To automatically use the next open port (on a single node), you can set this to `0`." + ) + + if args.module and args.no_python: + raise ValueError("--module and --no_python cannot be used together") + elif args.module: + args.module = True + elif args.no_python: + args.no_python = True + + current_env = os.environ.copy() + if args.debug: + current_env["ACCELERATE_DEBUG_MODE"] = "true" + gpu_ids = getattr(args, "gpu_ids", "all") + if gpu_ids != "all" and args.gpu_ids is not None: + if is_xpu_available(): + current_env["ZE_AFFINITY_MASK"] = gpu_ids + elif is_mlu_available(): + current_env["MLU_VISIBLE_DEVICES"] = gpu_ids + elif is_sdaa_available(): + current_env["SDAA_VISIBLE_DEVICES"] = gpu_ids + elif is_musa_available(): + current_env["MUSA_VISIBLE_DEVICES"] = gpu_ids + elif is_npu_available(): + current_env["ASCEND_RT_VISIBLE_DEVICES"] = gpu_ids + elif is_hpu_available(): + current_env["HABANA_VISIBLE_MODULES"] = gpu_ids + elif is_neuron_available(): + current_env["NEURON_RT_VISIBLE_CORES"] = gpu_ids + else: + current_env["CUDA_VISIBLE_DEVICES"] = gpu_ids + try: + mixed_precision = PrecisionType(args.mixed_precision.lower()) + except ValueError: + raise ValueError( + f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}." + ) + + current_env["PYTHONPATH"] = env_var_path_add("PYTHONPATH", os.path.abspath(".")) + current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision) + if args.mixed_precision.lower() == "fp8": + if not is_fp8_available(): + raise RuntimeError( + "FP8 is not available on this machine. Please ensure that either Transformer Engine, MSAMP or torchao is installed." + ) + current_env = setup_fp8_env(args, current_env) + current_env["ACCELERATE_CONFIG_DS_FIELDS"] = str(args.deepspeed_fields_from_accelerate_config).lower() + current_env["ACCELERATE_USE_DEEPSPEED"] = "true" + if args.zero_stage is not None: + current_env["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(args.zero_stage) + if args.gradient_accumulation_steps is not None: + current_env["ACCELERATE_GRADIENT_ACCUMULATION_STEPS"] = str(args.gradient_accumulation_steps) + if args.gradient_clipping is not None: + current_env["ACCELERATE_GRADIENT_CLIPPING"] = str(args.gradient_clipping).lower() + if args.offload_optimizer_device is not None: + current_env["ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE"] = str(args.offload_optimizer_device).lower() + if args.offload_param_device is not None: + current_env["ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_DEVICE"] = str(args.offload_param_device).lower() + if args.zero3_init_flag is not None: + current_env["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = str(args.zero3_init_flag).lower() + if args.zero3_save_16bit_model is not None: + current_env["ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL"] = str(args.zero3_save_16bit_model).lower() + if args.deepspeed_config_file is not None: + current_env["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = str(args.deepspeed_config_file) + if args.enable_cpu_affinity: + current_env["ACCELERATE_CPU_AFFINITY"] = "1" + if args.deepspeed_moe_layer_cls_names is not None: + current_env["ACCELERATE_DEEPSPEED_MOE_LAYER_CLS_NAMES"] = str(args.deepspeed_moe_layer_cls_names) + + if args.use_parallelism_config: + current_env = prepare_extend_env_parallelism_config(args, current_env) + + return cmd, current_env + + +def prepare_tpu( + args: argparse.Namespace, current_env: dict[str, str], pod: bool = False +) -> tuple[argparse.Namespace, dict[str, str]]: + """ + Prepares and returns an environment with the correct TPU environment variables. + """ + if args.mixed_precision == "bf16" and is_torch_xla_available(check_is_tpu=True): + if args.downcast_bf16: + current_env["XLA_DOWNCAST_BF16"] = "1" + else: + current_env["XLA_USE_BF16"] = "1" + if args.debug: + current_env["ACCELERATE_DEBUG_MODE"] = "true" + if pod: + # Take explicit args and set them up for XLA + args.vm = args.tpu_vm + args.tpu = args.tpu_name + return args, current_env + + +def _convert_nargs_to_dict(nargs: list[str]) -> dict[str, str]: + if len(nargs) < 0: + return {} + # helper function to infer type for argsparser + + def _infer_type(s): + try: + s = float(s) + + if s // 1 == s: + return int(s) + return s + except ValueError: + return s + + parser = argparse.ArgumentParser() + _, unknown = parser.parse_known_args(nargs) + for index, argument in enumerate(unknown): + if argument.startswith(("-", "--")): + action = None + if index + 1 < len(unknown): # checks if next index would be in list + if unknown[index + 1].startswith(("-", "--")): # checks if next element is an key + # raise an error if element is store_true or store_false + raise ValueError( + "SageMaker doesn’t support argparse actions for `store_true` or `store_false`. Please define explicit types" + ) + else: # raise an error if last element is store_true or store_false + raise ValueError( + "SageMaker doesn’t support argparse actions for `store_true` or `store_false`. Please define explicit types" + ) + # adds argument to parser based on action_store true + if action is None: + parser.add_argument(argument, type=_infer_type) + else: + parser.add_argument(argument, action=action) + + return { + key: (literal_eval(value) if value in ("True", "False") else value) + for key, value in parser.parse_args(nargs).__dict__.items() + } + + +def prepare_sagemager_args_inputs( + sagemaker_config: SageMakerConfig, args: argparse.Namespace +) -> tuple[argparse.Namespace, dict[str, Any]]: + # configure environment + print("Configuring Amazon SageMaker environment") + os.environ["AWS_DEFAULT_REGION"] = sagemaker_config.region + + # configure credentials + if sagemaker_config.profile is not None: + os.environ["AWS_PROFILE"] = sagemaker_config.profile + elif args.aws_access_key_id is not None and args.aws_secret_access_key is not None: + os.environ["AWS_ACCESS_KEY_ID"] = args.aws_access_key_id + os.environ["AWS_SECRET_ACCESS_KEY"] = args.aws_secret_access_key + else: + raise OSError("You need to provide an aws_access_key_id and aws_secret_access_key when not using aws_profile") + + # extract needed arguments + source_dir = os.path.dirname(args.training_script) + if not source_dir: # checks if string is empty + source_dir = "." + entry_point = os.path.basename(args.training_script) + if not entry_point.endswith(".py"): + raise ValueError(f'Your training script should be a python script and not "{entry_point}"') + + print("Converting Arguments to Hyperparameters") + hyperparameters = _convert_nargs_to_dict(args.training_script_args) + + try: + mixed_precision = PrecisionType(args.mixed_precision.lower()) + except ValueError: + raise ValueError( + f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}." + ) + + try: + dynamo_backend = DynamoBackend(args.dynamo_backend.upper()) + except ValueError: + raise ValueError( + f"Unknown dynamo backend: {args.dynamo_backend.upper()}. Choose between {DynamoBackend.list()}." + ) + + # Environment variables to be set for use during training job + environment = { + "ACCELERATE_USE_SAGEMAKER": "true", + "ACCELERATE_MIXED_PRECISION": str(mixed_precision), + "ACCELERATE_DYNAMO_BACKEND": dynamo_backend.value, + "ACCELERATE_DYNAMO_MODE": args.dynamo_mode, + "ACCELERATE_DYNAMO_USE_FULLGRAPH": str(args.dynamo_use_fullgraph), + "ACCELERATE_DYNAMO_USE_DYNAMIC": str(args.dynamo_use_dynamic), + "ACCELERATE_DYNAMO_USE_REGIONAL_COMPILATION": str(args.dynamo_use_regional_compilation), + "ACCELERATE_SAGEMAKER_DISTRIBUTED_TYPE": sagemaker_config.distributed_type.value, + } + if args.mixed_precision.lower() == "fp8": + if not is_fp8_available(): + raise RuntimeError( + "FP8 is not available on this machine. Please ensure that either Transformer Engine, MSAMP or torchao is installed." + ) + environment = setup_fp8_env(args, environment) + # configure distribution set up + distribution = None + if sagemaker_config.distributed_type == SageMakerDistributedType.DATA_PARALLEL: + distribution = {"smdistributed": {"dataparallel": {"enabled": True}}} + + # configure sagemaker inputs + sagemaker_inputs = None + if sagemaker_config.sagemaker_inputs_file is not None: + print(f"Loading SageMaker Inputs from {sagemaker_config.sagemaker_inputs_file} file") + sagemaker_inputs = {} + with open(sagemaker_config.sagemaker_inputs_file) as file: + for i, line in enumerate(file): + if i == 0: + continue + l = line.split("\t") + sagemaker_inputs[l[0]] = l[1].strip() + print(f"Loaded SageMaker Inputs: {sagemaker_inputs}") + + # configure sagemaker metrics + sagemaker_metrics = None + if sagemaker_config.sagemaker_metrics_file is not None: + print(f"Loading SageMaker Metrics from {sagemaker_config.sagemaker_metrics_file} file") + sagemaker_metrics = [] + with open(sagemaker_config.sagemaker_metrics_file) as file: + for i, line in enumerate(file): + if i == 0: + continue + l = line.split("\t") + metric_dict = { + "Name": l[0], + "Regex": l[1].strip(), + } + sagemaker_metrics.append(metric_dict) + print(f"Loaded SageMaker Metrics: {sagemaker_metrics}") + + # configure session + print("Creating Estimator") + args = { + "image_uri": sagemaker_config.image_uri, + "entry_point": entry_point, + "source_dir": source_dir, + "role": sagemaker_config.iam_role_name, + "transformers_version": sagemaker_config.transformers_version, + "pytorch_version": sagemaker_config.pytorch_version, + "py_version": sagemaker_config.py_version, + "base_job_name": sagemaker_config.base_job_name, + "instance_count": sagemaker_config.num_machines, + "instance_type": sagemaker_config.ec2_instance_type, + "debugger_hook_config": False, + "distribution": distribution, + "hyperparameters": hyperparameters, + "environment": environment, + "metric_definitions": sagemaker_metrics, + } + + if sagemaker_config.additional_args is not None: + args = merge_dicts(sagemaker_config.additional_args, args) + return args, sagemaker_inputs + + +def env_var_path_add(env_var_name, path_to_add): + """ + Extends a path-based environment variable's value with a new path and returns the updated value. It's up to the + caller to set it in os.environ. + """ + paths = [p for p in os.environ.get(env_var_name, "").split(":") if len(p) > 0] + paths.append(str(path_to_add)) + return ":".join(paths) + + +class PrepareForLaunch: + """ + Prepare a function that will launched in a distributed setup. + + Args: + launcher (`Callable`): + The function to launch. + distributed_type ([`~state.DistributedType`]): + The distributed type to prepare for. + debug (`bool`, *optional*, defaults to `False`): + Whether or not this is a debug launch. + """ + + def __init__(self, launcher, distributed_type="NO", debug=False): + self.launcher = launcher + self.distributed_type = DistributedType(distributed_type) + self.debug = debug + + def __call__(self, index, *args): + if self.debug: + world_size = int(os.environ.get("WORLD_SIZE")) + rdv_file = os.environ.get("ACCELERATE_DEBUG_RDV_FILE") + torch.distributed.init_process_group( + "gloo", + rank=index, + store=torch.distributed.FileStore(rdv_file, world_size), + world_size=world_size, + ) + elif self.distributed_type in ( + DistributedType.MULTI_GPU, + DistributedType.MULTI_MLU, + DistributedType.MULTI_MUSA, + DistributedType.MULTI_NPU, + DistributedType.MULTI_XPU, + DistributedType.MULTI_CPU, + DistributedType.MULTI_NEURON, + ): + # Prepare the environment for torch.distributed + os.environ["LOCAL_RANK"] = str(index) + nproc = int(os.environ.get("NPROC", 1)) + node_rank = int(os.environ.get("NODE_RANK", 0)) + os.environ["RANK"] = str(nproc * node_rank + index) + + os.environ["FORK_LAUNCHED"] = str(1) + self.launcher(*args) diff --git a/accelerate/utils/megatron_lm.py b/accelerate/utils/megatron_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..ff8900cbad888af9dceca2adc2600a02d213d304 --- /dev/null +++ b/accelerate/utils/megatron_lm.py @@ -0,0 +1,1248 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import math +import os +from abc import ABC +from functools import partial + +import torch +import torch.nn.functional as F +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ..optimizer import AcceleratedOptimizer +from ..scheduler import AcceleratedScheduler +from .imports import is_megatron_lm_available +from .operations import recursively_apply, send_to_device + + +if is_megatron_lm_available(): + from megatron.core import mpu, tensor_parallel + from megatron.core.distributed import DistributedDataParallel as LocalDDP + from megatron.core.distributed import finalize_model_grads + from megatron.core.enums import ModelType + from megatron.core.num_microbatches_calculator import get_num_microbatches + from megatron.core.optimizer import get_megatron_optimizer + from megatron.core.parallel_state import get_tensor_model_parallel_group, get_tensor_model_parallel_src_rank + from megatron.core.pipeline_parallel import get_forward_backward_func + from megatron.core.utils import get_model_config + from megatron.legacy.data.dataset_utils import build_train_valid_test_datasets + from megatron.legacy.model import BertModel, T5Model + from megatron.legacy.model.classification import Classification + from megatron.training import ( + get_args, + get_tensorboard_writer, + get_tokenizer, + print_rank_last, + ) + from megatron.training.arguments import ( + _add_data_args, + _add_validation_args, + core_transformer_config_from_args, + parse_args, + validate_args, + ) + from megatron.training.checkpointing import load_args_from_checkpoint, load_checkpoint, save_checkpoint + from megatron.training.global_vars import set_global_variables + from megatron.training.gpt_builders import gpt_builder + from megatron.training.initialize import ( + _compile_dependencies, + _init_autoresume, + _initialize_distributed, + _set_random_seed, + set_jit_fusion_options, + write_args_to_tensorboard, + ) + from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding + from megatron.training.training import ( + build_train_valid_test_data_iterators, + get_optimizer_param_scheduler, + num_floating_point_operations, + setup_model_and_optimizer, + train_step, + training_log, + ) + from megatron.training.utils import ( + average_losses_across_data_parallel_group, + calc_params_l2_norm, + get_ltor_masks_and_position_ids, + ) + + +# model utilities +def model_provider_func(pre_process=True, post_process=True, add_encoder=True, add_decoder=True): + """Build the model.""" + args = get_args() + mode = "pre-training" if args.pretraining_flag else "fine-tuning" + if args.rank == 0: + print(f"Building {args.model_type_name} model in the {mode} mode.") + print( + "The Megatron LM model weights are initialized at random in `accelerator.prepare`. " + "Please use `accelerator.load_checkpoint` to load a pre-trained checkpoint matching the distributed setup." + ) + config = core_transformer_config_from_args(args) + if args.model_type_name == "bert": + if args.pretraining_flag: + num_tokentypes = 2 if args.bert_binary_head else 0 + model = BertModel( + config=config, + num_tokentypes=num_tokentypes, + add_binary_head=args.bert_binary_head, + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + ) + else: + model = Classification( + config=config, + num_classes=args.num_labels, + num_tokentypes=2, + pre_process=pre_process, + post_process=post_process, + ) + elif args.model_type_name == "gpt": + # use the latest gpt builder to build the model and set use_legacy_models to False + args.use_legacy_models = False + model = gpt_builder(args, pre_process, post_process, vp_stage=None, config=None) + elif args.model_type_name == "t5": + model = T5Model( + config=config, + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + add_encoder=add_encoder, + add_decoder=add_decoder, + ) + else: + raise ValueError(f"Unsupported model type: {args.model_type_name}") + return model + + +def prepare_model_optimizer_scheduler(accelerator): + accelerator.print("Preparing model optimizer scheduler") + args = get_args() + if accelerator.state.megatron_lm_plugin.custom_prepare_model_function is not None: + if accelerator.state.megatron_lm_plugin.custom_model_provider_function is None: + raise ValueError( + "You must provide a `custom_model_provider_function` when using a `custom_prepare_model_function`." + ) + custom_model_provider_func = accelerator.state.megatron_lm_plugin.custom_model_provider_function + model = accelerator.state.megatron_lm_plugin.custom_prepare_model_function(custom_model_provider_func) + optimizer = prepare_optimizer(accelerator, model) + scheduler = prepare_scheduler(accelerator, optimizer, scheduler=None) + else: + model_type = ModelType.encoder_or_decoder + if args.model_type_name == "t5": + model_type = ModelType.encoder_and_decoder + model_provider_func_ = model_provider_func + if accelerator.state.megatron_lm_plugin.custom_model_provider_function is not None: + model_provider_func_ = accelerator.state.megatron_lm_plugin.custom_model_provider_function + (model, optimizer, scheduler) = setup_model_and_optimizer( + model_provider_func_, + model_type, + ) + args.model_len = len(model) + return model, optimizer, scheduler + + +# dataloader utilities +class MegatronLMDummyDataLoader: + """ + Dummy dataloader presents model parameters or param groups, this is primarily used to follow conventional training + + Args: + **dataset_kwargs: Megatron data arguments. + """ + + def __init__(self, **dataset_kwargs): + parser = argparse.ArgumentParser() + parser = _add_data_args(parser) + parser = _add_validation_args(parser) + data_args = parser.parse_known_args() + self.dataset_args = vars(data_args[0]) + self.dataset_args.update(dataset_kwargs) + self.dataset_args["megatron_dataset_flag"] = True + + def set_megatron_data_args(self): + args = get_args() + for key, value in self.dataset_args.items(): + old_value = getattr(args, key, "") + if old_value != value: + print( + f"WARNING: MegatronLMDummyDataLoader overriding arguments for {key}:{old_value} with {key}:{value}" + ) + setattr(args, key, value) + + def get_train_valid_test_datasets_provider(self, accelerator): + def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build train, valid, and test datasets.""" + args = get_args() + dataset_args = { + "data_prefix": args.data_path if isinstance(args.data_path, (list, tuple)) else [args.data_path], + "splits_string": args.split, + "train_valid_test_num_samples": train_val_test_num_samples, + "seed": args.seed, + } + if args.model_type_name == "bert": + dataset_args.update( + { + "max_seq_length": args.seq_length, + "binary_head": args.bert_binary_head, + } + ) + elif args.model_type_name == "gpt": + dataset_args.update( + { + "max_seq_length": args.seq_length, + } + ) + elif args.model_type_name == "t5": + dataset_args.update( + { + "max_seq_length": args.encoder_seq_length, + "max_seq_length_dec": args.decoder_seq_length, + "dataset_type": "t5", + } + ) + else: + raise ValueError(f"Unsupported model type: {args.model_type_name}") + train_ds, valid_ds, test_ds = build_train_valid_test_datasets(**dataset_args) + return train_ds, valid_ds, test_ds + + if accelerator.state.megatron_lm_plugin.custom_megatron_datasets_provider_function is not None: + return accelerator.state.megatron_lm_plugin.custom_megatron_datasets_provider_function + try: + args = get_args() + # Use '--no-use-pep517 -e' to pip install nvidia's megatron from source + if args.model_type_name == "bert": + from pretrain_bert import train_valid_test_datasets_provider + + train_valid_test_datasets_provider.is_distributed = True + return train_valid_test_datasets_provider + elif args.model_type_name == "gpt": + from pretrain_gpt import train_valid_test_datasets_provider + + train_valid_test_datasets_provider.is_distributed = True + return train_valid_test_datasets_provider + elif args.model_type_name == "t5": + from pretrain_t5 import train_valid_test_datasets_provider + + train_valid_test_datasets_provider.is_distributed = True + return train_valid_test_datasets_provider + except ImportError: + pass + return train_valid_test_datasets_provider + + def build_train_valid_test_data_iterators(self, accelerator): + args = get_args() + + train_valid_test_dataset_provider = self.get_train_valid_test_datasets_provider(accelerator) + if args.virtual_pipeline_model_parallel_size is not None: + train_data_iterator = [] + valid_data_iterator = [] + test_data_iterator = [] + for i in range(getattr(args, "model_len", 0)): + mpu.set_virtual_pipeline_model_parallel_rank(i) + iterators = build_train_valid_test_data_iterators(train_valid_test_dataset_provider) + train_data_iterator.append(iterators[0]) + valid_data_iterator.append(iterators[1]) + test_data_iterator.append(iterators[2]) + else: + train_data_iterator, valid_data_iterator, test_data_iterator = build_train_valid_test_data_iterators( + train_valid_test_dataset_provider + ) + + return train_data_iterator, valid_data_iterator, test_data_iterator + + +def _handle_megatron_data_iterator(accelerator, data_iterator): + class DummyMegatronDataloader: + def __iter__(self): + return self + + def __next__(self): + return {} + + is_data_iterator_empty = data_iterator is None + is_src_data_iterator_empty = torch.tensor(is_data_iterator_empty, dtype=torch.bool, device=accelerator.device) + torch.distributed.broadcast( + is_src_data_iterator_empty, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group() + ) + if not is_src_data_iterator_empty and is_data_iterator_empty: + return DummyMegatronDataloader() + return data_iterator + + +def prepare_data_loader(accelerator, dataloader): + accelerator.print("Preparing dataloader") + args = get_args() + if not args.megatron_dataset_flag: + from ..data_loader import _PYTORCH_DATALOADER_KWARGS, prepare_data_loader + + micro_batch_size = args.micro_batch_size * args.num_micro_batches + kwargs = {k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k]) for k in _PYTORCH_DATALOADER_KWARGS} + if kwargs["batch_size"] is None: + if isinstance(kwargs["sampler"], torch.utils.data.BatchSampler): + kwargs["sampler"].batch_size = micro_batch_size + else: + del kwargs["sampler"] + del kwargs["shuffle"] + del kwargs["batch_size"] + kwargs["batch_sampler"].batch_size = micro_batch_size + else: + del kwargs["batch_sampler"] + kwargs["batch_size"] = micro_batch_size + + dataloader = torch.utils.data.DataLoader(dataloader.dataset, **kwargs) + # split_batches: + # Megatron only needs to fetch different data between different dp groups, + # and does not need to split the data within the dp group. + return prepare_data_loader( + dataloader, + accelerator.device, + num_processes=mpu.get_data_parallel_world_size(), + process_index=mpu.get_data_parallel_rank(), + split_batches=False, + put_on_device=True, + rng_types=accelerator.rng_types.copy(), + dispatch_batches=accelerator.dispatch_batches, + ) + else: + if args.consumed_samples is not None: + ( + args.consumed_train_samples, + args.consumed_valid_samples, + args.consumed_test_samples, + ) = args.consumed_samples + else: + args.consumed_train_samples, args.consumed_valid_samples, args.consumed_test_samples = 0, 0, 0 + args.micro_batch_size = args.micro_batch_size * args.num_micro_batches + # In order to be compatible with data in transform format, + # it needs to increase the size of mbs first, + # and then split the large batch data into some mbs. + ( + train_data_iterator, + valid_data_iterator, + test_data_iterator, + ) = dataloader.build_train_valid_test_data_iterators(accelerator) + args.micro_batch_size = args.micro_batch_size // args.num_micro_batches + + train_data_iterator = _handle_megatron_data_iterator( + accelerator=accelerator, data_iterator=train_data_iterator + ) + valid_data_iterator = _handle_megatron_data_iterator( + accelerator=accelerator, data_iterator=valid_data_iterator + ) + test_data_iterator = _handle_megatron_data_iterator(accelerator=accelerator, data_iterator=test_data_iterator) + + return train_data_iterator, valid_data_iterator, test_data_iterator + + +# optimizer utilities +class MegatronLMOptimizerWrapper(AcceleratedOptimizer): + def __init__(self, optimizer): + super().__init__(optimizer, device_placement=False, scaler=None) + + def zero_grad(self, set_to_none=None): + pass # `model(**batch)` is doing that automatically. Therefore, its implementation is not needed + + def step(self): + pass # `model(**batch)` is doing that automatically. Therefore, its implementation is not needed + + @property + def step_was_skipped(self): + """Whether or not the optimizer step was done, or skipped because of gradient overflow.""" + return self.optimizer.skipped_iter + + +def prepare_optimizer(accelerator, model): + accelerator.print("Preparing optimizer") + args = get_args() + return get_megatron_optimizer(model, args.no_wd_decay_cond, args.scale_lr_cond, args.lr_mult) + + +# scheduler utilities +class MegatronLMDummyScheduler: + """ + Dummy scheduler presents model parameters or param groups, this is primarily used to follow conventional training + loop when scheduler config is specified in the deepspeed config file. + + Args: + optimizer (`torch.optim.optimizer.Optimizer`): + The optimizer to wrap. + total_num_steps (int): + Total number of steps. + warmup_num_steps (int): + Number of steps for warmup. + **kwargs (additional keyword arguments, *optional*): + Other arguments. + """ + + def __init__(self, optimizer, total_num_steps=None, warmup_num_steps=0, **kwargs): + self.optimizer = optimizer + self.total_num_steps = total_num_steps + self.warmup_num_steps = warmup_num_steps + self.kwargs = kwargs + + +class MegatronLMSchedulerWrapper(AcceleratedScheduler): + def __init__(self, scheduler, optimizers): + super().__init__(scheduler, optimizers) + + def step(self, *args, **kwargs): + return # `model(**batch)` is doing that automatically. Therefore, its implementation is not needed + + +def prepare_scheduler(accelerator, optimizer, scheduler): + accelerator.print("Preparing scheduler") + scheduler = get_optimizer_param_scheduler(optimizer) + return scheduler + + +class AbstractTrainStep(ABC): + """Abstract class for batching, forward pass and loss handler.""" + + def __init__(self, name): + super().__init__() + self.name = name + + def get_batch_func(self, accelerator, megatron_dataset_flag): + pass + + def get_forward_step_func(self): + pass + + def get_loss_func(self, accelerator): + pass + + +class BertTrainStep(AbstractTrainStep): + """ + Bert train step class. + + Args: + args (`argparse.Namespace`): Megatron-LM arguments. + """ + + def __init__(self, accelerator, args): + super().__init__("BertTrainStep") + self.get_batch = self.get_batch_func(accelerator, args.megatron_dataset_flag) + self.loss_func = self.get_loss_func(accelerator, args.pretraining_flag, args.num_labels) + self.forward_step = self.get_forward_step_func(args.pretraining_flag, args.bert_binary_head) + if not args.model_return_dict: + self.model_output_class = None + else: + from transformers.modeling_outputs import SequenceClassifierOutput + + self.model_output_class = SequenceClassifierOutput + + def get_batch_func(self, accelerator, megatron_dataset_flag): + def get_batch_megatron(data_iterator): + """Build the batch.""" + + # Items and their type. + keys = ["text", "types", "labels", "is_random", "loss_mask", "padding_mask"] + datatype = torch.int64 + + # Broadcast data. + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + data_b = tensor_parallel.broadcast_data(keys, data, datatype) + + # Unpack. + tokens = data_b["text"].long() + types = data_b["types"].long() + sentence_order = data_b["is_random"].long() + loss_mask = data_b["loss_mask"].float() + lm_labels = data_b["labels"].long() + padding_mask = data_b["padding_mask"].long() + + return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask + + def get_batch_transformer(data_iterator): + """Build the batch.""" + data = next(data_iterator) + data = send_to_device(data, torch.cuda.current_device()) + + # Unpack. + tokens = data["input_ids"].long() + padding_mask = data["attention_mask"].long() + if "token_type_ids" in data: + types = data["token_type_ids"].long() + else: + types = None + if "labels" in data: + lm_labels = data["labels"].long() + loss_mask = (data["labels"] != -100).to(torch.float) + else: + lm_labels = None + loss_mask = None + if "next_sentence_label" in data: + sentence_order = data["next_sentence_label"].long() + else: + sentence_order = None + + return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask + + if accelerator.state.megatron_lm_plugin.custom_get_batch_function is not None: + return accelerator.state.megatron_lm_plugin.custom_get_batch_function + if megatron_dataset_flag: + try: + # Use '--no-use-pep517 -e' to pip install nvidia's megatron from source + from pretrain_bert import get_batch + + return get_batch + except ImportError: + pass + return get_batch_megatron + else: + return get_batch_transformer + + def get_loss_func(self, accelerator, pretraining_flag, num_labels): + def loss_func_pretrain(loss_mask, sentence_order, output_tensor): + lm_loss_, sop_logits = output_tensor + + lm_loss_ = lm_loss_.float() + loss_mask = loss_mask.float() + lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum() + + if sop_logits is not None: + sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), sentence_order.view(-1), ignore_index=-1) + sop_loss = sop_loss.float() + loss = lm_loss + sop_loss + averaged_losses = average_losses_across_data_parallel_group([lm_loss, sop_loss]) + return loss, {"lm loss": averaged_losses[0], "sop loss": averaged_losses[1]} + + else: + loss = lm_loss + averaged_losses = average_losses_across_data_parallel_group([lm_loss]) + return loss, {"lm loss": averaged_losses[0]} + + def loss_func_finetune(labels, logits): + if num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + elif self.num_labels > 1 and (labels.dtype in (torch.long, torch.int)): + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, num_labels), labels.view(-1)) + else: + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + averaged_losses = average_losses_across_data_parallel_group([loss]) + return loss, {"loss": averaged_losses[0]} + + if accelerator.state.megatron_lm_plugin.custom_loss_function is not None: + return accelerator.state.megatron_lm_plugin.custom_loss_function + if pretraining_flag: + return loss_func_pretrain + else: + return loss_func_finetune + + def get_forward_step_func(self, pretraining_flag, bert_binary_head): + def forward_step(data_iterator, model): + """Forward step.""" + tokens, types, sentence_order, loss_mask, labels, padding_mask = self.get_batch(data_iterator) + if not bert_binary_head: + types = None + # Forward pass through the model. + if pretraining_flag: + output_tensor = model(tokens, padding_mask, tokentype_ids=types, lm_labels=labels) + return output_tensor, partial(self.loss_func, loss_mask, sentence_order) + else: + logits = model(tokens, padding_mask, tokentype_ids=types) + return logits, partial(self.loss_func, labels) + + return forward_step + + +class GPTTrainStep(AbstractTrainStep): + """ + GPT train step class. + + Args: + args (`argparse.Namespace`): Megatron-LM arguments. + """ + + def __init__(self, accelerator, args): + super().__init__("GPTTrainStep") + self.get_batch = self.get_batch_func(accelerator, args.megatron_dataset_flag) + self.loss_func = self.get_loss_func(accelerator) + self.forward_step = self.get_forward_step_func() + if args.vocab_file is not None: + tokenizer = get_tokenizer() + self.eod_token = tokenizer.eod + self.eod_token = args.eos_token_id + self.pad_token = args.eos_token_id + self.reset_position_ids = args.reset_position_ids + self.reset_attention_mask = args.reset_attention_mask + self.eod_mask_loss = args.eod_mask_loss + if not args.model_return_dict: + self.model_output_class = None + else: + from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions + + self.model_output_class = CausalLMOutputWithCrossAttentions + + def get_batch_func(self, accelerator, megatron_dataset_flag): + def get_batch_megatron(data_iterator): + """Generate a batch""" + # Items and their type. + keys = ["text"] + datatype = torch.int64 + + # Broadcast data. + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + data_b = tensor_parallel.broadcast_data(keys, data, datatype) + + # Unpack. + tokens_ = data_b["text"].long() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + # Get the masks and position ids. + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, + eod_token=self.eod_token, + pad_token=self.eod_token, + reset_position_ids=self.reset_position_ids, + reset_attention_mask=self.reset_attention_mask, + eod_mask_loss=self.eod_mask_loss, + pad_mask_loss=True, + ) + return tokens, labels, loss_mask, attention_mask, position_ids + + def get_batch_transformer(data_iterator): + data = next(data_iterator) + data = {"input_ids": data["input_ids"]} + data = send_to_device(data, torch.cuda.current_device()) + + tokens_ = data["input_ids"].long() + padding = torch.zeros((tokens_.shape[0], 1), dtype=tokens_.dtype, device=tokens_.device) + self.eod_token + tokens_ = torch.concat([tokens_, padding], dim=1) + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + # Get the masks and position ids. + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, + eod_token=self.eod_token, + pad_token=self.eod_token, + reset_position_ids=self.reset_position_ids, + reset_attention_mask=self.reset_attention_mask, + eod_mask_loss=self.eod_mask_loss, + pad_mask_loss=True, + ) + return tokens, labels, loss_mask, attention_mask, position_ids + + if accelerator.state.megatron_lm_plugin.custom_get_batch_function is not None: + return accelerator.state.megatron_lm_plugin.custom_get_batch_function + if megatron_dataset_flag: + try: + # Use '--no-use-pep517 -e' to pip install nvidia's megatron from source + from pretrain_gpt import get_batch + + return get_batch + except ImportError: + pass + return get_batch_megatron + else: + return get_batch_transformer + + def get_loss_func(self, accelerator): + args = get_args() + + def loss_func(loss_mask, output_tensor): + if args.return_logits: + losses, logits = output_tensor + else: + losses = output_tensor + losses = losses.float() + loss_mask = loss_mask.view(-1).float() + if args.context_parallel_size > 1: + loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), loss_mask.sum().view(1)]) + torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group()) + loss = loss[0] / loss[1] + else: + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + + # Check individual rank losses are not NaN prior to DP all-reduce. + if args.check_for_nan_in_loss_and_grad: + global_rank = torch.distributed.get_rank() + assert not loss.isnan(), ( + f"Rank {global_rank}: found NaN in local forward loss calculation. " + f"Device: {torch.cuda.current_device()}, node: {os.uname()[1]}" + ) + + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([loss]) + + output_dict = {"lm loss": averaged_loss[0]} + if args.return_logits: + output_dict.update({"logits": logits}) + return loss, output_dict + + if accelerator.state.megatron_lm_plugin.custom_loss_function is not None: + return accelerator.state.megatron_lm_plugin.custom_loss_function + return loss_func + + def get_forward_step_func(self): + def forward_step(data_iterator, model): + """Forward step.""" + # Get the batch. + tokens, labels, loss_mask, attention_mask, position_ids = self.get_batch(data_iterator) + output_tensor = model(tokens, position_ids, attention_mask, labels=labels) + + return output_tensor, partial(self.loss_func, loss_mask) + + return forward_step + + +class T5TrainStep(AbstractTrainStep): + """ + T5 train step class. + + Args: + args (`argparse.Namespace`): Megatron-LM arguments. + """ + + def __init__(self, accelerator, args): + super().__init__("T5TrainStep") + self.get_batch = self.get_batch_func(accelerator, args.megatron_dataset_flag) + self.loss_func = self.get_loss_func(accelerator) + self.forward_step = self.get_forward_step_func() + if not args.model_return_dict: + self.model_output_class = None + else: + from transformers.modeling_outputs import Seq2SeqLMOutput + + self.model_output_class = Seq2SeqLMOutput + + @staticmethod + def attn_mask_postprocess(attention_mask): + # We create a 3D attention mask from a 2D tensor mask. + # [b, 1, s] + attention_mask_b1s = attention_mask.unsqueeze(1) + # [b, s, 1] + attention_mask_bs1 = attention_mask.unsqueeze(2) + # [b, s, s] + attention_mask_bss = attention_mask_b1s * attention_mask_bs1 + # Convert attention mask to binary: + extended_attention_mask = attention_mask_bss < 0.5 + return extended_attention_mask + + @staticmethod + def get_decoder_mask(seq_length, device): + attention_mask = torch.tril(torch.ones((1, seq_length, seq_length), device=device)) + attention_mask = attention_mask < 0.5 + return attention_mask + + @staticmethod + def get_enc_dec_mask(attention_mask, dec_seq_length, device): + batch_size, _ = attention_mask.shape + # We create a 3D attention mask from a 2D tensor mask. + # [b, 1, s] + attention_mask_b1s = attention_mask.unsqueeze(1) + # [b, s, 1] + attention_mask_bs1 = torch.ones((batch_size, dec_seq_length, 1), device=device) + attention_mask_bss = attention_mask_bs1 * attention_mask_b1s + extended_attention_mask = attention_mask_bss < 0.5 + return extended_attention_mask + + def get_batch_func(self, accelerator, megatron_dataset_flag): + def get_batch_megatron(data_iterator): + """Build the batch.""" + + keys = ["text_enc", "text_dec", "labels", "loss_mask", "enc_mask", "dec_mask", "enc_dec_mask"] + datatype = torch.int64 + + # Broadcast data. + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + data_b = tensor_parallel.broadcast_data(keys, data, datatype) + + # Unpack. + tokens_enc = data_b["text_enc"].long() + tokens_dec = data_b["text_dec"].long() + labels = data_b["labels"].long() + loss_mask = data_b["loss_mask"].float() + + enc_mask = data_b["enc_mask"] < 0.5 + dec_mask = data_b["dec_mask"] < 0.5 + enc_dec_mask = data_b["enc_dec_mask"] < 0.5 + + return tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask, enc_dec_mask + + def get_batch_transformer(data_iterator): + """Build the batch.""" + data = next(data_iterator) + data = send_to_device(data, torch.cuda.current_device()) + + tokens_enc = data["input_ids"].long() + labels = data["labels"].long() + loss_mask = (labels != -100).to(torch.float) + if "decoder_input_ids" in data: + tokens_dec = data["decoder_input_ids"].long() + else: + tokens_dec = labels.new_zeros(labels.shape, device=labels.device, dtype=torch.long) + tokens_dec[..., 1:] = labels[..., :-1].clone() + tokens_dec[..., 0] = 0 + tokens_dec.masked_fill_(tokens_dec == -100, 0) + enc_mask = T5TrainStep.attn_mask_postprocess(data["attention_mask"].long()) + dec_mask = T5TrainStep.get_decoder_mask(tokens_dec.shape[1], tokens_dec.device) + enc_dec_mask = T5TrainStep.get_enc_dec_mask( + data["attention_mask"].long(), tokens_dec.shape[1], tokens_dec.device + ) + + return tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask, enc_dec_mask + + if accelerator.state.megatron_lm_plugin.custom_get_batch_function is not None: + return accelerator.state.megatron_lm_plugin.custom_get_batch_function + if megatron_dataset_flag: + try: + # Use '--no-use-pep517 -e' to pip install nvidia's megatron from source + from pretrain_t5 import get_batch + + return get_batch + except ImportError: + pass + return get_batch_megatron + else: + return get_batch_transformer + + def get_loss_func(self, accelerator): + def loss_func(loss_mask, output_tensor): + lm_loss_ = output_tensor.float() + lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum() + + loss = lm_loss + averaged_losses = average_losses_across_data_parallel_group([lm_loss]) + + return loss, {"lm loss": averaged_losses[0]} + + if accelerator.state.megatron_lm_plugin.custom_loss_function is not None: + return accelerator.state.megatron_lm_plugin.custom_loss_function + return loss_func + + def get_forward_step_func(self): + def forward_step(data_iterator, model): + """Forward step.""" + # Get the batch. + tokens_enc, tokens_dec, loss_mask, lm_labels, enc_mask, dec_mask, enc_dec_mask = self.get_batch( + data_iterator + ) + # Forward model lm_labels + output_tensor = model( + tokens_enc, tokens_dec, enc_mask, dec_mask, enc_dec_mask, tokentype_ids=None, lm_labels=lm_labels + ) + + return output_tensor, partial(self.loss_func, loss_mask) + + return forward_step + + +def finish_mpu_init(): + # torch.distributed initialization + args = get_args() + # Pytorch distributed. + _initialize_distributed(None, None, None) + + # Random seeds for reproducibility. + if args.rank == 0: + print(f"> setting random seeds to {args.seed} ...") + _set_random_seed(args.seed, args.data_parallel_random_init) + + +# initialize megatron setup +def initialize(accelerator, extra_args_provider=None, args_defaults=None): + if args_defaults is None: + args_defaults = {} + accelerator.print("Initializing Megatron-LM") + assert torch.cuda.is_available(), "Megatron requires CUDA." + + # Parse arguments + args = parse_args(extra_args_provider, ignore_unknown_args=True) + + # Set defaults + for key, value in args_defaults.items(): + if getattr(args, key, None) is not None: + if args.rank == 0: + print( + f"WARNING: overriding default arguments for {key}:{getattr(args, key)} with {key}:{value}", + flush=True, + ) + setattr(args, key, value) + + if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False): + assert args.load is not None, "--use-checkpoints-args requires --load argument" + load_args_from_checkpoint(args) + + validate_args(args) + + # set global args, build tokenizer, and set adlr-autoresume, + # tensorboard-writer, and timers. + set_global_variables(args, build_tokenizer=False) + + # Megatron's MPU is the master. Complete initialization right away. + finish_mpu_init() + + # Autoresume. + _init_autoresume() + + # Compile dependencies. + _compile_dependencies() + + # Set pytorch JIT layer fusion options and warmup JIT functions. + set_jit_fusion_options() + args = get_args() + if getattr(args, "padded_vocab_size", None) is None: + args.padded_vocab_size = _vocab_size_with_padding(args.orig_vocab_size, args) + if args.model_type_name == "bert" and args.pretraining_flag and args.num_labels == 2: + args.bert_binary_head = True + else: + args.bert_binary_head = False + args.iteration = 0 + + +class MegatronEngine(torch.nn.Module): + """ + Megatron-LM model wrapper + + Args: + accelerator (:class:`~accelerate.Accelerator`): The accelerator object to use. + model: Megatron-LM model + optimizer: Megatron-LM optimizer + lr_scheduler: Megatron-LM lr scheduler + """ + + def __init__(self, accelerator, model, optimizer, scheduler): + super().__init__() + self.module = model + self.base_model = model[0] + self.optimizer = optimizer + self.scheduler = scheduler + args = get_args() + if accelerator.state.megatron_lm_plugin.custom_train_step_class is not None: + self.train_step_handler = accelerator.state.megatron_lm_plugin.custom_train_step_class( + args, **accelerator.state.megatron_lm_plugin.custom_train_step_kwargs + ) + elif args.model_type_name == "bert": + self.train_step_handler = BertTrainStep(accelerator, args) + elif args.model_type_name == "gpt": + self.train_step_handler = GPTTrainStep(accelerator, args) + elif args.model_type_name == "t5": + self.train_step_handler = T5TrainStep(accelerator, args) + else: + raise ValueError(f"Unsupported model type: {args.model_type_name}") + self.optimizer.skipped_iter = False + + # Tracking loss. + self.total_loss_dict = {} + self.eval_total_loss_dict = {} + self.iteration = 0 + self.report_memory_flag = True + self.num_floating_point_operations_so_far = 0 + self.module_config = None + if args.tensorboard_dir is not None: + write_args_to_tensorboard() + + def get_module_config(self): + args = get_args() + config = get_model_config(self.module[0]) + # Setup some training config params + config.grad_scale_func = self.optimizer.scale_loss + if isinstance(self.module[0], LocalDDP) and args.overlap_grad_reduce: + assert config.no_sync_func is None, ( + "When overlap_grad_reduce is True, config.no_sync_func must be None; " + "a custom no_sync_func is not supported when overlapping grad-reduce" + ) + config.no_sync_func = [model_chunk.no_sync for model_chunk in self.module] + if len(self.module) == 1: + config.no_sync_func = config.no_sync_func[0] + if args.delay_grad_reduce: + config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in self.module] + if len(self.module) == 1: + config.grad_sync_func = config.grad_sync_func[0] + if args.overlap_param_gather and args.delay_param_gather: + config.param_sync_func = [ + lambda x: self.optimizer.finish_param_sync(model_index, x) for model_index in range(len(self.module)) + ] + if len(self.module) == 1: + config.param_sync_func = config.param_sync_func[0] + config.finalize_model_grads_func = finalize_model_grads + return config + + def train(self): + for model_module in self.module: + model_module.train() + + if self.module_config is None: + self.module_config = self.get_module_config() + + self.log_eval_results() + + def eval(self): + for model_module in self.module: + model_module.eval() + + if self.module_config is None: + self.module_config = self.get_module_config() + + def get_batch_data_iterator(self, batch_data): + args = get_args() + data_chunks = [] + if len(batch_data) > 0: + if args.num_micro_batches > 1: + for i in range(0, args.num_micro_batches): + data_chunks.append( + { + k: v[i * args.micro_batch_size : (i + 1) * args.micro_batch_size] + for k, v in batch_data.items() + } + ) + else: + data_chunks = [batch_data] + + if len(self.module) > 1: + batch_data_iterator = ( + [iter(data_chunks) for _ in range(len(self.module))] + if len(batch_data) > 0 + else [None] * len(self.module) + ) + else: + batch_data_iterator = iter(data_chunks) if len(batch_data) > 0 else None + return batch_data_iterator + + def train_step(self, **batch_data): + """ + Training step for Megatron-LM + + Args: + batch_data (:obj:`dict`): The batch data to train on. + """ + + batch_data_iterator = self.get_batch_data_iterator(batch_data) + + loss_reduced, skipped_iter, _, _, _, grad_norm, num_zeros_in_grad = train_step( + forward_step_func=self.train_step_handler.forward_step, + data_iterator=batch_data_iterator, + model=self.module, + optimizer=self.optimizer, + opt_param_scheduler=self.scheduler, + config=self.module_config, + forward_backward_func=get_forward_backward_func(), + ) + + self.optimizer.skipped_iter = skipped_iter == 1 + + return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad + + def eval_step(self, **batch_data): + """ + Evaluation step for Megatron-LM + + Args: + batch_data (:obj:`dict`): The batch data to evaluate on. + """ + + args = get_args() + batch_data_iterator = self.get_batch_data_iterator(batch_data) + forward_backward_func = get_forward_backward_func() + loss_dicts = forward_backward_func( + forward_step_func=self.train_step_handler.forward_step, + data_iterator=batch_data_iterator, + model=self.module, + num_microbatches=get_num_microbatches(), + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + forward_only=True, + ) + # Empty unused memory + if args.empty_unused_memory_level >= 1: + torch.cuda.empty_cache() + + args.consumed_valid_samples += ( + mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches() + ) + + if mpu.is_pipeline_last_stage(ignore_virtual=True): + # Average loss across microbatches. + loss_reduced = {} + for key in loss_dicts[0]: + losses_reduced_for_key = [x[key] for x in loss_dicts] + if len(losses_reduced_for_key[0].shape) == 0: + loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key) + else: + loss_reduced[key] = torch.concat(losses_reduced_for_key) + return loss_reduced + return {} + + def forward(self, **batch_data): + # During training, we use train_step() + # model(**batch_data) performs following operations by delegating it to `self.train_step`: + # 1. Prepare **batch_data for Tendor, Pipeline and Model Parallelism + # 2. Set grad to zero. + # 3. forward pass and backward pass using Pipeline Parallelism + # 4. Empty unused memory. + # 5. Reduce gradients. + # 6. Update parameters. + # 7. Gather params when using Distributed Optimizer (Data Parallelism). + # 8. Update learning rate if scheduler is specified. + # 9. Empty unused memory. + # 10. Average loss across microbatches and across DP ranks. + # + # During evaluation, we use eval_step() + args = get_args() + if self.module[0].training: + loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = self.train_step(**batch_data) + self.iteration += 1 + batch_size = mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches() + args.consumed_train_samples += batch_size + self.num_floating_point_operations_so_far += num_floating_point_operations(args, batch_size) + if args.tensorboard_dir is not None: + # Logging. + loss_scale = self.optimizer.get_loss_scale().item() + params_norm = None + if args.log_params_norm: + params_norm = calc_params_l2_norm(self.model) + self.report_memory_flag = training_log( + loss_dict, + self.total_loss_dict, + self.optimizer.param_groups[0]["lr"], + self.iteration, + loss_scale, + self.report_memory_flag, + skipped_iter, + grad_norm, + params_norm, + num_zeros_in_grad, + ) + else: + loss_dict = self.eval_step(**batch_data) + if args.tensorboard_dir is not None: + for key in loss_dict: + self.eval_total_loss_dict[key] = ( + self.eval_total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + loss_dict[key] + ) + self.eval_total_loss_dict[key + "_num_iters"] = self.eval_total_loss_dict.get( + key + "_num_iters", torch.cuda.FloatTensor([0.0]) + ) + torch.cuda.FloatTensor([1.0]) + + loss = torch.tensor(0.0, device=torch.cuda.current_device()) + for key in loss_dict: + if len(loss_dict[key].shape) == 0: + loss += loss_dict[key] + + logits = None + if "logits" in loss_dict: + logits = loss_dict["logits"] + if self.train_step_handler.model_output_class is not None: + return self.train_step_handler.model_output_class(loss=loss, logits=logits) + return loss + + def log_eval_results(self): + args = get_args() + if args.tensorboard_dir is None or self.iteration == 0: + return + args = get_args() + writer = get_tensorboard_writer() + string = f"validation loss at iteration {self.iteration} | " + for key in self.eval_total_loss_dict: + if key.endswith("_num_iters"): + continue + value = self.eval_total_loss_dict[key] / self.eval_total_loss_dict[key + "_num_iters"] + string += f"{key} value: {value} | " + ppl = math.exp(min(20, value.item())) + if args.pretraining_flag: + string += f"{key} PPL: {ppl} | " + if writer: + writer.add_scalar(f"{key} validation", value.item(), self.iteration) + if args.pretraining_flag: + writer.add_scalar(f"{key} validation ppl", ppl, self.iteration) + + length = len(string) + 1 + print_rank_last("-" * length) + print_rank_last(string) + print_rank_last("-" * length) + self.eval_total_loss_dict = {} + + def save_checkpoint(self, output_dir): + self.log_eval_results() + args = get_args() + args.save = output_dir + torch.distributed.barrier() + save_checkpoint( + self.iteration, + self.module, + self.optimizer, + self.scheduler, + num_floating_point_operations_so_far=self.num_floating_point_operations_so_far, + ) + torch.distributed.barrier() + + def load_checkpoint(self, input_dir): + args = get_args() + args.load = input_dir + args.consumed_train_samples = 0 + args.consumed_valid_samples = 0 + torch.distributed.barrier() + iteration, num_floating_point_operations_so_far = load_checkpoint(self.module, self.optimizer, self.scheduler) + torch.distributed.barrier() + self.iteration = iteration + self.num_floating_point_operations_so_far = num_floating_point_operations_so_far + if args.fp16 and self.iteration == 0: + self.optimizer.reload_model_params() + + +# other utilities +def avg_losses_across_data_parallel_group(losses): + """ + Average losses across data parallel group. + + Args: + losses (List[Tensor]): List of losses to average across data parallel group. + """ + + return average_losses_across_data_parallel_group(losses) + + +def gather_across_data_parallel_groups(tensor): + """ + Recursively gather tensor in a nested list/tuple/dictionary of tensors from data parallel ranks. + + Args: + tensor (nested list/tuple/dictionary of `torch.Tensor`): + The data to gather across data parallel ranks. + + """ + + def _gpu_gather_one(tensor): + if tensor.ndim == 0: + tensor = tensor.clone()[None] + output_tensors = [ + torch.empty_like(tensor) + for _ in range(torch.distributed.get_world_size(group=mpu.get_data_parallel_group())) + ] + torch.distributed.all_gather(output_tensors, tensor, group=mpu.get_data_parallel_group()) + return torch.cat(output_tensors, dim=0) + + return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True) diff --git a/accelerate/utils/memory.py b/accelerate/utils/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..f12d96fd188b70236e00302ed1916f510f0113db --- /dev/null +++ b/accelerate/utils/memory.py @@ -0,0 +1,184 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A collection of utilities for ensuring that training can always occur. Heavily influenced by the +[toma](https://github.com/BlackHC/toma) library. +""" + +import functools +import gc +import inspect +from typing import Optional + +import torch + +from .imports import ( + is_cuda_available, + is_hpu_available, + is_mlu_available, + is_mps_available, + is_musa_available, + is_neuron_available, + is_npu_available, + is_sdaa_available, + is_xpu_available, +) + + +def clear_device_cache(garbage_collection=False): + """ + Clears the device cache by calling `torch.{backend}.empty_cache`. Can also run `gc.collect()`, but do note that + this is a *considerable* slowdown and should be used sparingly. + """ + if garbage_collection: + gc.collect() + + if is_xpu_available(): + torch.xpu.empty_cache() + elif is_mlu_available(): + torch.mlu.empty_cache() + elif is_sdaa_available(): + torch.sdaa.empty_cache() + elif is_musa_available(): + torch.musa.empty_cache() + elif is_npu_available(): + torch.npu.empty_cache() + elif is_mps_available(min_version="2.0"): + torch.mps.empty_cache() + elif is_cuda_available(): + torch.cuda.empty_cache() + elif is_hpu_available(): + # torch.hpu.empty_cache() # not available on hpu as it reserves all device memory for the current process + pass + elif is_neuron_available(): + # Not sure it actually does something, but adding for consistency with other backends + torch.neuron.empty_cache() + + +def release_memory(*objects): + """ + Releases memory from `objects` by setting them to `None` and calls `gc.collect()` and `torch.cuda.empty_cache()`. + Returned objects should be reassigned to the same variables. + + Args: + objects (`Iterable`): + An iterable of objects + Returns: + A list of `None` objects to replace `objects` + + Example: + + ```python + >>> import torch + >>> from accelerate.utils import release_memory + + >>> a = torch.ones(1000, 1000).cuda() + >>> b = torch.ones(1000, 1000).cuda() + >>> a, b = release_memory(a, b) + ``` + """ + if not isinstance(objects, list): + objects = list(objects) + for i in range(len(objects)): + objects[i] = None + clear_device_cache(garbage_collection=True) + return objects + + +def should_reduce_batch_size(exception: Exception) -> bool: + """ + Checks if `exception` relates to CUDA out-of-memory, XPU out-of-memory, CUDNN not supported, or CPU out-of-memory + + Args: + exception (`Exception`): + An exception + """ + _statements = [ + " out of memory.", # OOM for CUDA, HIP, XPU + "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED.", # CUDNN SNAFU + "DefaultCPUAllocator: can't allocate memory", # CPU OOM + "FATAL ERROR :: MODULE:PT_DEVMEM Allocation failed", # HPU OOM + ] + if isinstance(exception, RuntimeError) and len(exception.args) == 1: + return any(err in exception.args[0] for err in _statements) + return False + + +def find_executable_batch_size( + function: Optional[callable] = None, + starting_batch_size: int = 128, + reduce_batch_size_fn: Optional[callable] = None, +): + """ + A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or + CUDNN, the batch size is multiplied by 0.9 and passed to `function` + + `function` must take in a `batch_size` parameter as its first argument. + + Args: + function (`callable`, *optional*): + A function to wrap + starting_batch_size (`int`, *optional*): + The batch size to try and fit into memory + + Example: + + ```python + >>> from accelerate.utils import find_executable_batch_size + + + >>> @find_executable_batch_size(starting_batch_size=128) + ... def train(batch_size, model, optimizer): + ... ... + + + >>> train(model, optimizer) + ``` + """ + if function is None: + return functools.partial(find_executable_batch_size, starting_batch_size=starting_batch_size) + + batch_size = starting_batch_size + if reduce_batch_size_fn is None: + + def reduce_batch_size_fn(): + nonlocal batch_size + batch_size = int(batch_size * 0.9) + return batch_size + + def decorator(*args, **kwargs): + nonlocal batch_size + clear_device_cache(garbage_collection=True) + params = list(inspect.signature(function).parameters.keys()) + # Guard against user error + if len(params) < (len(args) + 1): + arg_str = ", ".join([f"{arg}={value}" for arg, value in zip(params[1:], args[1:])]) + raise TypeError( + f"Batch size was passed into `{function.__name__}` as the first argument when called." + f"Remove this as the decorator already does so: `{function.__name__}({arg_str})`" + ) + while True: + if batch_size == 0: + raise RuntimeError("No executable batch size found, reached zero.") + try: + return function(batch_size, *args, **kwargs) + except Exception as e: + if should_reduce_batch_size(e): + clear_device_cache(garbage_collection=True) + batch_size = reduce_batch_size_fn() + else: + raise + + return decorator diff --git a/accelerate/utils/modeling.py b/accelerate/utils/modeling.py new file mode 100644 index 0000000000000000000000000000000000000000..7bfe1ec8e6738d4be8ac4f95dbf359cca13be02b --- /dev/null +++ b/accelerate/utils/modeling.py @@ -0,0 +1,2187 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import gc +import inspect +import json +import logging +import os +import re +import shutil +import tempfile +import warnings +from collections import OrderedDict, defaultdict +from typing import Optional, Union + +import torch +from torch import distributed as dist +from torch import nn + +from ..state import AcceleratorState +from .constants import SAFE_WEIGHTS_NAME, WEIGHTS_NAME +from .dataclasses import AutocastKwargs, CustomDtype, DistributedType +from .imports import ( + is_hpu_available, + is_mlu_available, + is_mps_available, + is_musa_available, + is_npu_available, + is_peft_available, + is_sdaa_available, + is_torch_xla_available, + is_xpu_available, +) +from .memory import clear_device_cache +from .offload import load_offloaded_weight, offload_weight, save_offload_index +from .tqdm import is_tqdm_available, tqdm +from .versions import is_torch_version + + +if is_npu_available(check_device=False): + import torch_npu # noqa: F401 + +if is_mlu_available(check_device=False): + import torch_mlu # noqa: F401 + +if is_sdaa_available(check_device=False): + import torch_sdaa # noqa: F401 + +if is_musa_available(check_device=False): + import torch_musa # noqa: F401 + +from safetensors import safe_open +from safetensors.torch import load_file as safe_load_file + + +WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json" + +logger = logging.getLogger(__name__) + + +def is_peft_model(model): + from .other import extract_model_from_parallel + + if is_peft_available(): + from peft import PeftModel + + return is_peft_available() and isinstance(extract_model_from_parallel(model), PeftModel) + + +def check_device_same(first_device, second_device): + """ + Utility method to check if two `torch` devices are similar. When dealing torch accelerator devices(e.g. cuda, xpu), + torch throws `False` for `torch.device("cuda") == torch.device("cuda:0")` whereas they should be the same + + Args: + first_device (`torch.device`): + First device to check + second_device (`torch.device`): + Second device to check + """ + if first_device.type != second_device.type: + return False + + if first_device.type != "cpu" and first_device.index is None: + # In case the first_device is an torch accelerator device(e.g. cuda, xpu) and have + # the index attribute set to `None`, default it to `0` + first_device = torch.device(first_device.type, index=0) + + if second_device.type != "cpu" and second_device.index is None: + # In case the second_device is an torch accelerator device(e.g. cuda, xpu) and have + # the index attribute set to `None`, default it to `0` + second_device = torch.device(second_device.type, index=0) + + return first_device == second_device + + +def convert_file_size_to_int(size: Union[int, str]): + """ + Converts a size expressed as a string with digits an unit (like `"5MB"`) to an integer (in bytes). + + Args: + size (`int` or `str`): The size to convert. Will be directly returned if an `int`. + + Example: + + ```py + >>> convert_file_size_to_int("1MiB") + 1048576 + ``` + """ + mem_size = -1 + err_msg = ( + f"`size` {size} is not in a valid format. Use an integer for bytes, or a string with an unit (like '5.0GB')." + ) + try: + if isinstance(size, int): + mem_size = size + elif size.upper().endswith("GIB"): + mem_size = int(float(size[:-3]) * (2**30)) + elif size.upper().endswith("MIB"): + mem_size = int(float(size[:-3]) * (2**20)) + elif size.upper().endswith("KIB"): + mem_size = int(float(size[:-3]) * (2**10)) + elif size.upper().endswith("GB"): + int_size = int(float(size[:-2]) * (10**9)) + mem_size = int_size // 8 if size.endswith("b") else int_size + elif size.upper().endswith("MB"): + int_size = int(float(size[:-2]) * (10**6)) + mem_size = int_size // 8 if size.endswith("b") else int_size + elif size.upper().endswith("KB"): + int_size = int(float(size[:-2]) * (10**3)) + mem_size = int_size // 8 if size.endswith("b") else int_size + except ValueError: + raise ValueError(err_msg) + + if mem_size < 0: + raise ValueError(err_msg) + return mem_size + + +def dtype_byte_size(dtype: torch.dtype): + """ + Returns the size (in bytes) occupied by one parameter of type `dtype`. + + Example: + + ```py + >>> dtype_byte_size(torch.float32) + 4 + ``` + """ + if dtype == torch.bool: + return 1 / 8 + elif dtype == CustomDtype.INT2: + return 1 / 4 + elif dtype == CustomDtype.INT4: + return 1 / 2 + elif dtype == CustomDtype.FP8: + return 1 + elif is_torch_version(">=", "2.1.0") and dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + return 1 + bit_search = re.search(r"[^\d](\d+)$", str(dtype)) + if bit_search is None: + raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") + bit_size = int(bit_search.groups()[0]) + return bit_size // 8 + + +def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]: + """ + Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For + example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is + guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with + non-overlapping lifetimes may have the same id. + """ + _SIZE = { + torch.int64: 8, + torch.float32: 4, + torch.int32: 4, + torch.bfloat16: 2, + torch.float16: 2, + torch.int16: 2, + torch.uint8: 1, + torch.int8: 1, + torch.bool: 1, + torch.float64: 8, + } + try: + storage_ptr = tensor.untyped_storage().data_ptr() + storage_size = tensor.untyped_storage().nbytes() + except Exception: + try: + # Fallback for torch==1.10 + storage_ptr = tensor.storage().data_ptr() + storage_size = tensor.storage().size() * _SIZE[tensor.dtype] + except NotImplementedError: + # Fallback for meta storage + storage_ptr = 0 + # On torch >=2.0 this is the tensor size + storage_size = tensor.nelement() * _SIZE[tensor.dtype] + + return tensor.device, storage_ptr, storage_size + + +def set_module_tensor_to_device( + module: nn.Module, + tensor_name: str, + device: Union[int, str, torch.device], + value: Optional[torch.Tensor] = None, + dtype: Optional[Union[str, torch.dtype]] = None, + fp16_statistics: Optional[torch.HalfTensor] = None, + tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None, + non_blocking: bool = False, + clear_cache: bool = True, +): + """ + A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing + `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). + + Args: + module (`torch.nn.Module`): + The module in which the tensor we want to move lives. + tensor_name (`str`): + The full name of the parameter/buffer. + device (`int`, `str` or `torch.device`): + The device on which to set the tensor. + value (`torch.Tensor`, *optional*): + The value of the tensor (useful when going from the meta device to any other device). + dtype (`torch.dtype`, *optional*): + If passed along the value of the parameter will be cast to this `dtype`. Otherwise, `value` will be cast to + the dtype of the existing parameter in the model. + fp16_statistics (`torch.HalfTensor`, *optional*): + The list of fp16 statistics to set on the module, used for 8 bit model serialization. + tied_params_map (Dict[int, Dict[torch.device, torch.Tensor]], *optional*, defaults to `None`): + A map of current data pointers to dictionaries of devices to already dispatched tied weights. For a given + execution device, this parameter is useful to reuse the first available pointer of a shared weight on the + device for all others, instead of duplicating memory. + non_blocking (`bool`, *optional*, defaults to `False`): + If `True`, the device transfer will be asynchronous with respect to the host, if possible. + clear_cache (`bool`, *optional*, defaults to `True`): + Whether or not to clear the device cache after setting the tensor on the device. + """ + # Recurse if needed + if "." in tensor_name: + splits = tensor_name.split(".") + for split in splits[:-1]: + new_module = getattr(module, split) + if new_module is None: + raise ValueError(f"{module} has no attribute {split}.") + module = new_module + tensor_name = splits[-1] + + if tensor_name not in module._parameters and tensor_name not in module._buffers: + raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") + is_buffer = tensor_name in module._buffers + old_value = getattr(module, tensor_name) + + # Treat the case where old_value (or a custom `value`, typically offloaded to RAM/disk) belongs to a tied group, and one of the weight + # in the tied group has already been dispatched to the device, by avoiding reallocating memory on the device and just copying the pointer. + if ( + value is not None + and tied_params_map is not None + and value.data_ptr() in tied_params_map + and device in tied_params_map[value.data_ptr()] + ): + module._parameters[tensor_name] = tied_params_map[value.data_ptr()][device] + return + elif ( + tied_params_map is not None + and old_value.data_ptr() in tied_params_map + and device in tied_params_map[old_value.data_ptr()] + ): + module._parameters[tensor_name] = tied_params_map[old_value.data_ptr()][device] + return + + if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None: + raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.") + + param = module._parameters[tensor_name] if tensor_name in module._parameters else None + param_cls = type(param) + + if value is not None: + # We can expect mismatches when using bnb 4bit since Params4bit will reshape and pack the weights. + # In other cases, we want to make sure we're not loading checkpoints that do not match the config. + if old_value.shape != value.shape and param_cls.__name__ != "Params4bit": + raise ValueError( + f'Trying to set a tensor of shape {value.shape} in "{tensor_name}" (which has shape {old_value.shape}), this looks incorrect.' + ) + + if dtype is None: + # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model + value = value.to(old_value.dtype, non_blocking=non_blocking) + elif not str(value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): + value = value.to(dtype, non_blocking=non_blocking) + + device_quantization = None + with torch.no_grad(): + # leave it on cpu first before moving them to device + # # fix the case where the device is meta, we don't want to put it on cpu because there is no data =0 + if ( + param is not None + and param.device.type not in ("cuda", "xpu") + and torch.device(device).type in ("cuda", "xpu") + and param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"] + ): + device_quantization = device + device = "cpu" + # `torch.Tensor.to()` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)). + if isinstance(device, int): + if is_npu_available(): + device = f"npu:{device}" + elif is_mlu_available(): + device = f"mlu:{device}" + elif is_sdaa_available(): + device = f"sdaa:{device}" + elif is_musa_available(): + device = f"musa:{device}" + elif is_hpu_available(): + device = "hpu" + if "xpu" in str(device) and not is_xpu_available(): + raise ValueError(f'{device} is not available, you should use device="cpu" instead') + if value is None: + new_value = old_value.to(device, non_blocking=non_blocking) + if dtype is not None and device in ["meta", torch.device("meta")]: + if not str(old_value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): + new_value = new_value.to(dtype, non_blocking=non_blocking) + + if not is_buffer: + module._parameters[tensor_name] = param_cls(new_value, requires_grad=old_value.requires_grad) + elif isinstance(value, torch.Tensor): + new_value = value.to(device, non_blocking=non_blocking) + else: + new_value = torch.tensor(value, device=device) + if device_quantization is not None: + device = device_quantization + if is_buffer: + module._buffers[tensor_name] = new_value + elif value is not None or not check_device_same(torch.device(device), module._parameters[tensor_name].device): + param_cls = type(module._parameters[tensor_name]) + kwargs = module._parameters[tensor_name].__dict__ + if param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"]: + if param_cls.__name__ == "Int8Params" and new_value.dtype == torch.float32: + # downcast to fp16 if any - needed for 8bit serialization + new_value = new_value.to(torch.float16, non_blocking=non_blocking) + # quantize module that are going to stay on the cpu so that we offload quantized weights + if device == "cpu" and param_cls.__name__ == "Int8Params": + new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(0).to("cpu") + new_value.CB = new_value.CB.to("cpu") + new_value.SCB = new_value.SCB.to("cpu") + else: + new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to( + device, non_blocking=non_blocking + ) + elif param_cls.__name__ in ["QTensor", "QBitsTensor"]: + new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad).to( + device, non_blocking=non_blocking + ) + elif param_cls.__name__ in ["AffineQuantizedTensor"]: + new_value = new_value.to(device, non_blocking=non_blocking) + else: + new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to( + device, non_blocking=non_blocking + ) + + module._parameters[tensor_name] = new_value + if fp16_statistics is not None: + module._parameters[tensor_name].SCB = fp16_statistics.to(device, non_blocking=non_blocking) + del fp16_statistics + # as we put the weight to meta, it doesn't have SCB attr anymore. make sure that it is not a meta weight + if ( + module.__class__.__name__ == "Linear8bitLt" + and getattr(module.weight, "SCB", None) is None + and str(module.weight.device) != "meta" + ): + # quantize only if necessary + device_index = torch.device(device).index if torch.device(device).type in ["cuda", "xpu"] else None + if not getattr(module.weight, "SCB", None) and device_index is not None: + if module.bias is not None and module.bias.device.type != "meta": + # if a bias exists, we need to wait until the bias is set on the correct device + module = module.to(device_index) + elif module.bias is None: + # if no bias exists, we can quantize right away + module = module.to(device_index) + elif ( + module.__class__.__name__ == "Linear4bit" + and getattr(module.weight, "quant_state", None) is None + and str(module.weight.device) != "meta" + ): + # quantize only if necessary + device_index = torch.device(device).index if torch.device(device).type in ["cuda", "xpu"] else None + if not getattr(module.weight, "quant_state", None) and device_index is not None: + module.weight = module.weight.to(device_index) + + # clean pre and post forward hook + if clear_cache and device not in ("cpu", "meta"): + clear_device_cache() + + # When handling tied weights, we update tied_params_map to keep track of the tied weights that have already been allocated on the device in + # order to avoid duplicating memory, see above. + if ( + tied_params_map is not None + and old_value.data_ptr() in tied_params_map + and device not in tied_params_map[old_value.data_ptr()] + ): + tied_params_map[old_value.data_ptr()][device] = new_value + elif ( + value is not None + and tied_params_map is not None + and value.data_ptr() in tied_params_map + and device not in tied_params_map[value.data_ptr()] + ): + tied_params_map[value.data_ptr()][device] = new_value + + +def named_module_tensors( + module: nn.Module, include_buffers: bool = True, recurse: bool = False, remove_non_persistent: bool = False +): + """ + A helper function that gathers all the tensors (parameters + buffers) of a given module. If `include_buffers=True` + it's the same as doing `module.named_parameters(recurse=recurse) + module.named_buffers(recurse=recurse)`. + + Args: + module (`torch.nn.Module`): + The module we want the tensors on. + include_buffer (`bool`, *optional*, defaults to `True`): + Whether or not to include the buffers in the result. + recurse (`bool`, *optional`, defaults to `False`): + Whether or not to go look in every submodule or just return the direct parameters and buffers. + remove_non_persistent (`bool`, *optional*, defaults to `False`): + Whether or not to remove the non persistent buffer from the buffers. Useful only when include_buffers = + True + """ + yield from module.named_parameters(recurse=recurse) + + if include_buffers: + non_persistent_buffers = set() + if remove_non_persistent: + non_persistent_buffers = get_non_persistent_buffers(module, recurse=recurse) + for named_buffer in module.named_buffers(recurse=recurse): + name, _ = named_buffer + if name not in non_persistent_buffers: + yield named_buffer + + +def get_non_persistent_buffers(module: nn.Module, recurse: bool = False, fqns: bool = False): + """ + Gather all non persistent buffers of a given modules into a set + + Args: + module (`nn.Module`): + The module we want the non persistent buffers on. + recurse (`bool`, *optional*, defaults to `False`): + Whether or not to go look in every submodule or just return the direct non persistent buffers. + fqns (`bool`, *optional*, defaults to `False`): + Whether or not to return the fully-qualified names of the non persistent buffers. + """ + + non_persistent_buffers_set = module._non_persistent_buffers_set + if recurse: + for n, m in module.named_modules(): + if fqns: + non_persistent_buffers_set |= {n + "." + b for b in m._non_persistent_buffers_set} + else: + non_persistent_buffers_set |= m._non_persistent_buffers_set + + return non_persistent_buffers_set + + +def check_tied_parameters_in_config(model: nn.Module): + """ + Check if there is any indication in the given model that some weights should be tied. + + Args: + model (`torch.nn.Module`): The model to inspect + + Returns: + bool: True if the model needs to have tied weights + """ + + # based on model.tie_weights() method + has_tied_word_embedding = False + has_tied_encoder_decoder = False + has_tied_module = False + + if "PreTrainedModel" in [c.__name__ for c in inspect.getmro(model.__class__)]: + has_tied_word_embedding = False + model_decoder_config = None + if hasattr(model, "config"): + model_decoder_config = ( + model.config.get_text_config(decoder=True) + if hasattr(model.config, "get_text_config") + else model.config + ) + has_tied_word_embedding = ( + model_decoder_config is not None + and getattr(model_decoder_config, "tie_word_embeddings", False) + and model.get_output_embeddings() + ) + + has_tied_encoder_decoder = ( + hasattr(model, "config") + and getattr(model.config, "is_encoder_decoder", False) + and getattr(model.config, "tie_encoder_decoder", False) + ) + has_tied_module = any(hasattr(module, "_tie_weights") for module in model.modules()) + return any([has_tied_word_embedding, has_tied_encoder_decoder, has_tied_module]) + + +def _get_param_device(param, device_map): + if param in device_map: + return device_map[param] + parent_param = ".".join(param.split(".")[:-1]) + if parent_param == param: + raise ValueError(f"The `device_map` does not contain the module {param}.") + else: + return _get_param_device(parent_param, device_map) + + +def check_tied_parameters_on_same_device(tied_params, device_map): + """ + Check if tied parameters are on the same device + + Args: + tied_params (`List[List[str]]`): + A list of lists of parameter names being all tied together. + + device_map (`Dict[str, Union[int, str, torch.device]]`): + A map that specifies where each submodule should go. + + """ + for tie_param in tied_params: + tie_param_devices = {} + for param in tie_param: + tie_param_devices[param] = _get_param_device(param, device_map) + if len(set(tie_param_devices.values())) > 1: + logger.warning( + f"Tied parameters are on different devices: {tie_param_devices}. " + "Please modify your custom device map or set `device_map='auto'`. " + ) + + +def find_tied_parameters(model: torch.nn.Module, **kwargs) -> list[list[str]]: + """ + Find the tied parameters in a given model. + + + + The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore + them. + + + + Args: + model (`torch.nn.Module`): The model to inspect. + + Returns: + List[List[str]]: A list of lists of parameter names being all tied together. + + Example: + + ```py + >>> from collections import OrderedDict + >>> import torch.nn as nn + + >>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))])) + >>> model.linear2.weight = model.linear1.weight + >>> find_tied_parameters(model) + [['linear1.weight', 'linear2.weight']] + ``` + """ + + # get ALL model parameters and their names + all_named_parameters = {name: param for name, param in model.named_parameters(remove_duplicate=False)} + + # get ONLY unique named parameters, + # if parameter is tied and have multiple names, it will be included only once + no_duplicate_named_parameters = {name: param for name, param in model.named_parameters(remove_duplicate=True)} + + # the difference of the two sets will give us the tied parameters + tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys()) + + # 'tied_param_names' contains the names of parameters that are tied in the model, but we do not know + # which names refer to the same parameter. To identify this, we need to group them together. + tied_param_groups = {} + for tied_param_name in tied_param_names: + tied_param = all_named_parameters[tied_param_name] + for param_name, param in no_duplicate_named_parameters.items(): + # compare if parameters are the same, if so, group their names together + if param is tied_param: + if param_name not in tied_param_groups: + tied_param_groups[param_name] = [] + tied_param_groups[param_name].append(tied_param_name) + + return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()] + + +def retie_parameters(model, tied_params): + """ + Reties tied parameters in a given model if the link was broken (for instance when adding hooks). + + Args: + model (`torch.nn.Module`): + The model in which to retie parameters. + tied_params (`List[List[str]]`): + A mapping parameter name to tied parameter name as obtained by `find_tied_parameters`. + """ + for tied_group in tied_params: + param_to_tie = None + # two loops : the first one to set param_to_tie , the second one to change the values of tied_group + for param_name in tied_group: + module = model + splits = param_name.split(".") + for split in splits[:-1]: + module = getattr(module, split) + param = getattr(module, splits[-1]) + if param_to_tie is None and param.device != torch.device("meta"): + param_to_tie = param + break + if param_to_tie is not None: + for param_name in tied_group: + module = model + splits = param_name.split(".") + for split in splits[:-1]: + module = getattr(module, split) + setattr(module, splits[-1], param_to_tie) + + +def _get_proper_dtype(dtype: Union[str, torch.device]) -> torch.dtype: + """ + Just does torch.dtype(dtype) if necessary. + """ + if isinstance(dtype, str): + # We accept "torch.float16" or just "float16" + dtype = dtype.replace("torch.", "") + dtype = getattr(torch, dtype) + return dtype + + +def compute_module_sizes( + model: nn.Module, + dtype: Optional[Union[str, torch.device]] = None, + special_dtypes: Optional[dict[str, Union[str, torch.device]]] = None, + buffers_only: bool = False, +): + """ + Compute the size of each submodule of a given model. + """ + if dtype is not None: + dtype = _get_proper_dtype(dtype) + dtype_size = dtype_byte_size(dtype) + if special_dtypes is not None: + special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()} + special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()} + module_sizes = defaultdict(int) + + module_list = [] + + if not buffers_only: + module_list = named_module_tensors(model, recurse=True) + else: + module_list = model.named_buffers(recurse=True) + + for name, tensor in module_list: + if special_dtypes is not None and name in special_dtypes: + size = tensor.numel() * special_dtypes_size[name] + elif dtype is None: + size = tensor.numel() * dtype_byte_size(tensor.dtype) + elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): + # According to the code in set_module_tensor_to_device, these types won't be converted + # so use their original size here + size = tensor.numel() * dtype_byte_size(tensor.dtype) + else: + size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype)) + name_parts = name.split(".") + for idx in range(len(name_parts) + 1): + module_sizes[".".join(name_parts[:idx])] += size + + return module_sizes + + +def compute_module_total_buffer_size( + model: nn.Module, + dtype: Optional[Union[str, torch.device]] = None, + special_dtypes: Optional[dict[str, Union[str, torch.device]]] = None, +): + """ + Compute the total size of buffers in each submodule of a given model. + """ + module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes, buffers_only=True) + return module_sizes.get("", 0) + + +def get_max_layer_size( + modules: list[tuple[str, torch.nn.Module]], module_sizes: dict[str, int], no_split_module_classes: list[str] +): + """ + Utility function that will scan a list of named modules and return the maximum size used by one full layer. The + definition of a layer being: + - a module with no direct children (just parameters and buffers) + - a module whose class name is in the list `no_split_module_classes` + + Args: + modules (`List[Tuple[str, torch.nn.Module]]`): + The list of named modules where we want to determine the maximum layer size. + module_sizes (`Dict[str, int]`): + A dictionary mapping each layer name to its size (as generated by `compute_module_sizes`). + no_split_module_classes (`List[str]`): + A list of class names for layers we don't want to be split. + + Returns: + `Tuple[int, List[str]]`: The maximum size of a layer with the list of layer names realizing that maximum size. + """ + max_size = 0 + layer_names = [] + modules_to_treat = modules.copy() + while len(modules_to_treat) > 0: + module_name, module = modules_to_treat.pop(0) + modules_children = list(module.named_children()) if isinstance(module, torch.nn.Module) else [] + if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes: + # No splitting this one so we compare to the max_size + size = module_sizes[module_name] + if size > max_size: + max_size = size + layer_names = [module_name] + elif size == max_size: + layer_names.append(module_name) + else: + modules_to_treat = [(f"{module_name}.{n}", v) for n, v in modules_children] + modules_to_treat + return max_size, layer_names + + +def get_max_memory(max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None): + """ + Get the maximum memory available if nothing is passed, converts string to int otherwise. + """ + import psutil + + if max_memory is None: + max_memory = {} + # Make sure device is initialized on each device to have the right memory info. + if is_npu_available(): + for i in range(torch.npu.device_count()): + try: + _ = torch.tensor(0, device=torch.device("npu", i)) + max_memory[i] = torch.npu.mem_get_info(i)[0] + except Exception: + logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.") + continue + elif is_mlu_available(): + for i in range(torch.mlu.device_count()): + try: + _ = torch.tensor(0, device=torch.device("mlu", i)) + max_memory[i] = torch.mlu.mem_get_info(i)[0] + except Exception: + logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.") + continue + elif is_sdaa_available(): + for i in range(torch.sdaa.device_count()): + try: + _ = torch.tensor(0, device=torch.device("sdaa", i)) + max_memory[i] = torch.sdaa.mem_get_info(i)[0] + except Exception: + logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.") + continue + elif is_musa_available(): + for i in range(torch.musa.device_count()): + try: + _ = torch.tensor(0, device=torch.device("musa", i)) + max_memory[i] = torch.musa.mem_get_info(i)[0] + except Exception: + logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.") + continue + elif is_xpu_available(): + for i in range(torch.xpu.device_count()): + try: + _ = torch.tensor(0, device=torch.device("xpu", i)) + max_memory[i] = torch.xpu.mem_get_info(i)[0] + except Exception: + logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.") + continue + elif is_hpu_available(): + for i in range(torch.hpu.device_count()): + try: + _ = torch.tensor(0, device=torch.device("hpu", i)) + max_memory[i] = torch.hpu.mem_get_info(i)[0] + except Exception: + logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.") + continue + else: + for i in range(torch.cuda.device_count()): + try: + _ = torch.tensor([0], device=i) + max_memory[i] = torch.cuda.mem_get_info(i)[0] + except Exception: + logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.") + continue + # allocate everything in the mps device as the RAM is shared + if is_mps_available(): + max_memory["mps"] = psutil.virtual_memory().available + else: + max_memory["cpu"] = psutil.virtual_memory().available + return max_memory + + for key in max_memory: + if isinstance(max_memory[key], str): + max_memory[key] = convert_file_size_to_int(max_memory[key]) + + # Need to sort the device by type to make sure that we allocate the gpu first. + # As gpu/npu/xpu are represented by int, we need to sort them first. + gpu_devices = [k for k in max_memory.keys() if isinstance(k, int)] + gpu_devices.sort() + # check if gpu/npu/xpu devices are available and if not, throw a warning + if is_npu_available(): + num_devices = torch.npu.device_count() + elif is_mlu_available(): + num_devices = torch.mlu.device_count() + elif is_sdaa_available(): + num_devices = torch.sdaa.device_count() + elif is_musa_available(): + num_devices = torch.musa.device_count() + elif is_xpu_available(): + num_devices = torch.xpu.device_count() + elif is_hpu_available(): + num_devices = torch.hpu.device_count() + else: + num_devices = torch.cuda.device_count() + for device in gpu_devices: + if device >= num_devices or device < 0: + logger.warning(f"Device {device} is not available, available devices are {list(range(num_devices))}") + # Add the other devices in the preset order if they are available + all_devices = gpu_devices + [k for k in ["mps", "cpu", "disk"] if k in max_memory.keys()] + # Raise an error if a device is not recognized + for k in max_memory.keys(): + if k not in all_devices: + raise ValueError( + f"Device {k} is not recognized, available devices are integers(for GPU/XPU), 'mps', 'cpu' and 'disk'" + ) + max_memory = {k: max_memory[k] for k in all_devices} + + return max_memory + + +def clean_device_map(device_map: dict[str, Union[int, str, torch.device]], module_name: str = ""): + """ + Cleans a device_map by grouping all submodules that go on the same device together. + """ + # Get the value of the current module and if there is only one split across several keys, regroup it. + prefix = "" if module_name == "" else f"{module_name}." + values = [v for k, v in device_map.items() if k.startswith(prefix)] + if len(set(values)) == 1 and len(values) > 1: + for k in [k for k in device_map if k.startswith(prefix)]: + del device_map[k] + device_map[module_name] = values[0] + + # Recurse over the children + children_modules = [k for k in device_map.keys() if k.startswith(prefix) and len(k) > len(module_name)] + idx = len(module_name.split(".")) + 1 if len(module_name) > 0 else 1 + children_modules = set(".".join(k.split(".")[:idx]) for k in children_modules) + for child in children_modules: + clean_device_map(device_map, module_name=child) + + return device_map + + +def load_offloaded_weights(model, index, offload_folder): + """ + Loads the weights from the offload folder into the model. + + Args: + model (`torch.nn.Module`): + The model to load the weights into. + index (`dict`): + A dictionary containing the parameter name and its metadata for each parameter that was offloaded from the + model. + offload_folder (`str`): + The folder where the offloaded weights are stored. + """ + if index is None or len(index) == 0: + # Nothing to do + return + for param_name, metadata in index.items(): + if "SCB" in param_name: + continue + fp16_statistics = None + if "weight" in param_name and param_name.replace("weight", "SCB") in index.keys(): + weight_name = param_name.replace("weight", "SCB") + fp16_statistics = load_offloaded_weight( + os.path.join(offload_folder, f"{weight_name}.dat"), index[weight_name] + ) + tensor_file = os.path.join(offload_folder, f"{param_name}.dat") + weight = load_offloaded_weight(tensor_file, metadata) + set_module_tensor_to_device(model, param_name, "cpu", value=weight, fp16_statistics=fp16_statistics) + + +def get_module_leaves(module_sizes): + module_children = {} + for module in module_sizes: + if module == "" or "." not in module: + continue + parent = module.rsplit(".", 1)[0] + module_children[parent] = module_children.get(parent, 0) + 1 + leaves = [module for module in module_sizes if module_children.get(module, 0) == 0 and module != ""] + return leaves + + +def get_balanced_memory( + model: nn.Module, + max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None, + no_split_module_classes: Optional[list[str]] = None, + dtype: Optional[Union[str, torch.dtype]] = None, + special_dtypes: Optional[dict[str, Union[str, torch.device]]] = None, + low_zero: bool = False, +): + """ + Compute a `max_memory` dictionary for [`infer_auto_device_map`] that will balance the use of each available GPU. + + + + All computation is done analyzing sizes and dtypes of the model parameters. As a result, the model can be on the + meta device (as it would if initialized within the `init_empty_weights` context manager). + + + + Args: + model (`torch.nn.Module`): + The model to analyze. + max_memory (`Dict`, *optional*): + A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset. + Example: `max_memory={0: "1GB"}`. + no_split_module_classes (`List[str]`, *optional*): + A list of layer class names that should never be split across device (for instance any layer that has a + residual connection). + dtype (`str` or `torch.dtype`, *optional*): + If provided, the weights will be converted to that type when loaded. + special_dtypes (`Dict[str, Union[str, torch.device]]`, *optional*): + If provided, special dtypes to consider for some specific weights (will override dtype used as default for + all weights). + low_zero (`bool`, *optional*): + Minimizes the number of weights on GPU 0, which is convenient when it's used for other operations (like the + Transformers generate function). + """ + # Get default / clean up max_memory + user_not_set_max_memory = max_memory is None + max_memory = get_max_memory(max_memory) + + if is_npu_available(): + expected_device_type = "npu" + elif is_mlu_available(): + expected_device_type = "mlu" + elif is_sdaa_available(): + expected_device_type = "sdaa" + elif is_musa_available(): + expected_device_type = "musa" + elif is_xpu_available(): + expected_device_type = "xpu" + elif is_hpu_available(): + expected_device_type = "hpu" + elif is_mps_available(): + expected_device_type = "mps" + else: + expected_device_type = "cuda" + num_devices = len([d for d in max_memory if torch.device(d).type == expected_device_type and max_memory[d] > 0]) + + if num_devices == 0: + return max_memory + + if num_devices == 1: + # We cannot do low_zero on just one GPU, but we will still reserve some memory for the buffer + low_zero = False + # If user just asked us to handle memory usage, we should avoid OOM + if user_not_set_max_memory: + for key in max_memory.keys(): + if isinstance(key, int): + max_memory[key] *= 0.9 # 90% is a good compromise + logger.info( + f"We will use 90% of the memory on device {key} for storing the model, and 10% for the buffer to avoid OOM. " + "You can set `max_memory` in to a higher value to use more memory (at your own risk)." + ) + break # only one device + + module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes) + per_gpu = module_sizes[""] // (num_devices - 1 if low_zero else num_devices) + + # We can't just set the memory to model_size // num_devices as it will end being too small: each GPU will get + # slightly less layers and some layers will end up offload at the end. So this function computes a buffer size to + # add which is the biggest of: + # - the size of no split block (if applicable) + # - the mean of the layer sizes + if no_split_module_classes is None: + no_split_module_classes = [] + elif not isinstance(no_split_module_classes, (list, tuple)): + no_split_module_classes = [no_split_module_classes] + + # Identify the size of the no_split_block modules + if len(no_split_module_classes) > 0: + no_split_children = {} + for name, size in module_sizes.items(): + if name == "": + continue + submodule = model + for submodule_name in name.split("."): + submodule = getattr(submodule, submodule_name) + class_name = submodule.__class__.__name__ + if class_name in no_split_module_classes and class_name not in no_split_children: + no_split_children[class_name] = size + + if set(no_split_children.keys()) == set(no_split_module_classes): + break + buffer = max(no_split_children.values()) if len(no_split_children) > 0 else 0 + else: + buffer = 0 + + # Compute mean of final modules. In the first dict of module sizes, leaves are the parameters + leaves = get_module_leaves(module_sizes) + leaves_set = set(leaves) # Convert to set for O(1) membership testing + module_sizes = {n: v for n, v in module_sizes.items() if n not in leaves_set} + # Once removed, leaves are the final modules. + leaves = get_module_leaves(module_sizes) + mean_leaves = int(sum([module_sizes[n] for n in leaves]) / max(len(leaves), 1)) + buffer = int(1.25 * max(buffer, mean_leaves)) + per_gpu += buffer + + # Sorted list of GPUs id (we may have some gpu ids not included in the our max_memory list - let's ignore them) + gpus_idx_list = list( + sorted( + device_id for device_id, device_mem in max_memory.items() if isinstance(device_id, int) and device_mem > 0 + ) + ) + # The last device is left with max_memory just in case the buffer is not enough. + for idx in gpus_idx_list[:-1]: + max_memory[idx] = min(max_memory[0] if low_zero and idx == 0 else per_gpu, max_memory[idx]) + + if low_zero: + min_zero = max(0, module_sizes[""] - sum([max_memory[i] for i in range(1, num_devices)])) + max_memory[0] = min(min_zero, max_memory[0]) + + return max_memory + + +def calculate_maximum_sizes(model: torch.nn.Module): + "Computes the total size of the model and its largest layer" + sizes = compute_module_sizes(model) + # `transformers` models store this information for us + no_split_modules = getattr(model, "_no_split_modules", None) + if no_split_modules is None: + no_split_modules = [] + + modules_to_treat = ( + list(model.named_parameters(recurse=False)) + + list(model.named_children()) + + list(model.named_buffers(recurse=False)) + ) + largest_layer = get_max_layer_size(modules_to_treat, sizes, no_split_modules) + total_size = sizes[""] + return total_size, largest_layer + + +def _init_infer_auto_device_map( + model: nn.Module, + max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None, + no_split_module_classes: Optional[list[str]] = None, + dtype: Optional[Union[str, torch.dtype]] = None, + special_dtypes: Optional[dict[str, Union[str, torch.device]]] = None, +) -> tuple[ + list[Union[int, str]], + dict[Union[int, str], Union[int, str]], + list[Union[int, str]], + list[int], + dict[str, int], + list[list[str]], + list[str], + list[tuple[str, nn.Module]], +]: + """ + Initialize variables required for computing the device map for model allocation. + """ + max_memory = get_max_memory(max_memory) + if no_split_module_classes is None: + no_split_module_classes = [] + elif not isinstance(no_split_module_classes, (list, tuple)): + no_split_module_classes = [no_split_module_classes] + + devices = list(max_memory.keys()) + if "disk" not in devices: + devices.append("disk") + gpus = [device for device in devices if device not in ["cpu", "disk"]] + + # Devices that need to keep space for a potential offloaded layer. + if "mps" in gpus: + main_devices = ["mps"] + elif len(gpus) > 0: + main_devices = [gpus[0], "cpu"] + else: + main_devices = ["cpu"] + + module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes) + tied_parameters = find_tied_parameters(model) + if check_tied_parameters_in_config(model) and len(tied_parameters) == 0: + logger.warning( + "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function." + ) + + # Direct submodules and parameters + modules_to_treat = ( + list(model.named_parameters(recurse=False)) + + list(model.named_children()) + + list(model.named_buffers(recurse=False)) + ) + + return ( + devices, + max_memory, + main_devices, + gpus, + module_sizes, + tied_parameters, + no_split_module_classes, + modules_to_treat, + ) + + +def get_module_size_with_ties( + tied_params, + module_size, + module_sizes, + modules_to_treat, +) -> tuple[int, list[str], list[nn.Module]]: + """ + Calculate the total size of a module, including its tied parameters. + + Args: + tied_params (`List[str]`): The list of tied parameters. + module_size (`int`): The size of the module without tied parameters. + module_sizes (`Dict[str, int]`): A dictionary mapping each layer name to its size. + modules_to_treat (`List[Tuple[str, nn.Module]]`): The list of named modules to treat. + + Returns: + `Tuple[int, List[str], List[nn.Module]]`: The total size of the module, the names of the tied modules, and the + tied modules. + """ + if len(tied_params) < 1: + return module_size, [], [] + tied_module_names = [] + tied_modules = [] + + for tied_param in tied_params: + tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if tied_param.startswith(n + ".")][0] + tied_module_names.append(modules_to_treat[tied_module_index][0]) + tied_modules.append(modules_to_treat[tied_module_index][1]) + + module_size_with_ties = module_size + for tied_param, tied_module_name in zip(tied_params, tied_module_names): + module_size_with_ties += module_sizes[tied_module_name] - module_sizes[tied_param] + + return module_size_with_ties, tied_module_names, tied_modules + + +def fallback_allocate( + modules: list[tuple[str, nn.Module]], + module_sizes: dict[str, int], + size_limit: Union[int, str], + no_split_module_classes: Optional[list[str]] = None, + tied_parameters: Optional[list[list[str]]] = None, +) -> tuple[Optional[str], Optional[nn.Module], list[tuple[str, nn.Module]]]: + """ + Find a module that fits in the size limit using BFS and return it with its name and the remaining modules. + + Args: + modules (`List[Tuple[str, nn.Module]]`): + The list of named modules to search in. + module_sizes (`Dict[str, int]`): + A dictionary mapping each layer name to its size (as generated by `compute_module_sizes`). + size_limit (`Union[int, str]`): + The maximum size a module can have. + no_split_module_classes (`Optional[List[str]]`, *optional*): + A list of class names for layers we don't want to be split. + tied_parameters (`Optional[List[List[str]]`, *optional*): + A list of lists of parameter names being all tied together. + + Returns: + `Tuple[Optional[str], Optional[nn.Module], List[Tuple[str, nn.Module]]]`: A tuple containing: + - The name of the module that fits within the size limit. + - The module itself. + - The list of remaining modules after the found module is removed. + """ + try: + size_limit = convert_file_size_to_int(size_limit) + except ValueError: + return None, None, modules + + if no_split_module_classes is None: + no_split_module_classes = [] + + if tied_parameters is None: + tied_parameters = [] + + modules_to_search = modules.copy() + module_found = False + + while modules_to_search: + name, module = modules_to_search.pop(0) + + tied_param_groups = [ + tied_group + for tied_group in tied_parameters + if any(name + "." in k + "." for k in tied_group) and not all(name + "." in k + "." for k in tied_group) + ] + + tied_params = sum( + [[p for p in tied_group if name + "." not in p + "."] for tied_group in tied_param_groups], [] + ) + + module_size_with_ties, _, _ = get_module_size_with_ties( + tied_params, module_sizes[name], module_sizes, modules_to_search + ) + + # If the module fits in the size limit, we found it. + if module_size_with_ties <= size_limit: + module_found = True + break + + # The module is too big, we need to split it if possible. + modules_children = ( + [] + if isinstance(module, nn.Parameter) or isinstance(module, torch.Tensor) + else list(module.named_children()) + ) + + # Split fails, move to the next module + if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes: + continue + + # split is possible, add the children to the list of modules to search + modules_children = list(module.named_parameters(recurse=False)) + modules_children + modules_to_search = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_search + + if not module_found: + return None, None, modules + + # Prepare the module list for removal of the found module + current_names = [n for n, _ in modules] + dot_idx = [i for i, c in enumerate(name) if c == "."] + + for dot_index in dot_idx: + parent_name = name[:dot_index] + if parent_name in current_names: + parent_module_idx = current_names.index(parent_name) + _, parent_module = modules[parent_module_idx] + module_children = list(parent_module.named_parameters(recurse=False)) + list( + parent_module.named_children() + ) + modules = ( + modules[:parent_module_idx] + + [(f"{parent_name}.{n}", v) for n, v in module_children] + + modules[parent_module_idx + 1 :] + ) + current_names = [n for n, _ in modules] + + # Now the target module should be directly in the list + target_idx = current_names.index(name) + name, module = modules.pop(target_idx) + + return name, module, modules + + +def infer_auto_device_map( + model: nn.Module, + max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None, + no_split_module_classes: Optional[list[str]] = None, + dtype: Optional[Union[str, torch.dtype]] = None, + special_dtypes: Optional[dict[str, Union[str, torch.dtype]]] = None, + verbose: bool = False, + clean_result: bool = True, + offload_buffers: bool = False, + fallback_allocation: bool = False, +): + """ + Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk, + such that: + - we don't exceed the memory available of any of the GPU. + - if offload to the CPU is needed, there is always room left on GPU 0 to put back the layer offloaded on CPU that + has the largest size. + - if offload to the CPU is needed,we don't exceed the RAM available on the CPU. + - if offload to the disk is needed, there is always room left on the CPU to put back the layer offloaded on disk + that has the largest size. + + + + All computation is done analyzing sizes and dtypes of the model parameters. As a result, the model can be on the + meta device (as it would if initialized within the `init_empty_weights` context manager). + + + + Args: + model (`torch.nn.Module`): + The model to analyze. + max_memory (`Dict`, *optional*): + A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset. + Example: `max_memory={0: "1GB"}`. + no_split_module_classes (`List[str]`, *optional*): + A list of layer class names that should never be split across device (for instance any layer that has a + residual connection). + dtype (`str` or `torch.dtype`, *optional*): + If provided, the weights will be converted to that type when loaded. + special_dtypes (`Dict[str, Union[str, torch.device]]`, *optional*): + If provided, special dtypes to consider for some specific weights (will override dtype used as default for + all weights). + verbose (`bool`, *optional*, defaults to `False`): + Whether or not to provide debugging statements as the function builds the device_map. + clean_result (`bool`, *optional*, defaults to `True`): + Clean the resulting device_map by grouping all submodules that go on the same device together. + offload_buffers (`bool`, *optional*, defaults to `False`): + In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as + well as the parameters. + fallback_allocation (`bool`, *optional*, defaults to `False`): + When regular allocation fails, try to allocate a module that fits in the size limit using BFS. + """ + + # Initialize the variables + ( + devices, + max_memory, + main_devices, + gpus, + module_sizes, + tied_parameters, + no_split_module_classes, + modules_to_treat, + ) = _init_infer_auto_device_map(model, max_memory, no_split_module_classes, dtype, special_dtypes) + + device_map = OrderedDict() + current_device = 0 + device_memory_used = {device: 0 for device in devices} + device_buffer_sizes = {} + device_minimum_assignment_memory = {} + + # Initialize maximum largest layer, to know which space to keep in memory + max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, no_split_module_classes) + + # Ready ? This is going to be a bit messy. + while len(modules_to_treat) > 0: + name, module = modules_to_treat.pop(0) + if verbose: + print(f"\nTreating module {name}.") + # Max size in the remaining layers may have changed since we took one, so we maybe update it. + max_layer_names = [n for n in max_layer_names if n != name and not n.startswith(name + ".")] + if len(max_layer_names) == 0: + max_layer_size, max_layer_names = get_max_layer_size( + [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)], + module_sizes, + no_split_module_classes, + ) + # Assess size needed + module_size = module_sizes[name] + + # We keep relevant tied parameters only: one of the tied parameters in the group is inside the current module + # and the other is not. + # Note: If we are currently processing the name `compute.weight`, an other parameter named + # e.g. `compute.weight_submodule.parameter` + # needs to be considered outside the current module, hence the check with additional dots. + tied_param_groups = [ + tied_group + for tied_group in tied_parameters + if any(name + "." in k + "." for k in tied_group) and not all(name + "." in k + "." for k in tied_group) + ] + + if verbose and len(tied_param_groups) > 0: + print(f" Found the relevant tied param groups {tied_param_groups}") + + # Then we keep track of all the parameters that are tied to the current module, but not in the current module + tied_params = sum( + [[p for p in tied_group if name + "." not in p + "."] for tied_group in tied_param_groups], [] + ) + + if verbose and len(tied_params) > 0: + print(f" So those parameters need to be taken into account {tied_params}") + + device = devices[current_device] + current_max_size = max_memory[device] if device != "disk" else None + current_memory_reserved = 0 + # Reduce max size available by the largest layer. + if devices[current_device] in main_devices: + current_max_size = current_max_size - max_layer_size + current_memory_reserved = max_layer_size + + module_size_with_ties, tied_module_names, tied_modules = get_module_size_with_ties( + tied_params, module_size, module_sizes, modules_to_treat + ) + + # The module and its tied modules fit on the current device. + if current_max_size is None or device_memory_used[device] + module_size_with_ties <= current_max_size: + if verbose: + output = f"Putting {name}" + + if tied_module_names: + output += f" and {tied_module_names}" + else: + output += f" (size={module_size})" + + if current_max_size is not None: + output += f" (available={current_max_size - device_memory_used[device]})" + + output += f" on {device}." + print(output) + + device_memory_used[device] += module_size_with_ties + + # Assign the primary module to the device. + device_map[name] = device + + # Assign tied modules if any. + for tied_module_name in tied_module_names: + if tied_module_name in [m[0] for m in modules_to_treat]: + # Find the index of the tied module in the list + tied_module_index = next(i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name) + # Remove the tied module from the list to prevent reprocessing + modules_to_treat.pop(tied_module_index) + + # Assign the tied module to the device + device_map[tied_module_name] = device + + # Buffer Handling + if not offload_buffers and isinstance(module, nn.Module): + # Compute the total buffer size for the module + current_buffer_size = compute_module_total_buffer_size( + module, dtype=dtype, special_dtypes=special_dtypes + ) + # Update the buffer size on the device + device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size + + continue + + # The current module itself fits, so we try to split the tied modules. + if len(tied_params) > 0 and device_memory_used[device] + module_size <= current_max_size: + # can we split one of the tied modules to make it smaller or do we need to go on the next device? + if verbose: + print( + f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space " + f"available {current_max_size - device_memory_used[device]}, needed size {module_size_with_ties})." + ) + split_happened = False + for tied_module_name, tied_module in zip(tied_module_names, tied_modules): + tied_module_children = list(tied_module.named_children()) + if len(tied_module_children) == 0 or tied_module.__class__.__name__ in no_split_module_classes: + # can't break this one. + continue + + if verbose: + print(f"Splitting {tied_module_name}.") + tied_module_children = list(tied_module.named_parameters(recurse=False)) + tied_module_children + tied_module_children = [(f"{tied_module_name}.{n}", v) for n, v in tied_module_children] + tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0] + + modules_to_treat = ( + [(name, module)] + + modules_to_treat[:tied_module_index] + + tied_module_children + + modules_to_treat[tied_module_index + 1 :] + ) + # Update the max layer size. + max_layer_size, max_layer_names = get_max_layer_size( + [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)], + module_sizes, + no_split_module_classes, + ) + split_happened = True + break + + if split_happened: + continue + + # If the tied module is not split, we go to the next device + if verbose: + print("None of the tied module can be split, going to the next device.") + + # The current module itself doesn't fit, so we have to split it or go to the next device. + if device_memory_used[device] + module_size >= current_max_size: + # Split or not split? + modules_children = ( + [] + if isinstance(module, nn.Parameter) or isinstance(module, torch.Tensor) + else list(module.named_children()) + ) + if verbose: + print( + f"Not enough space on {devices[current_device]} to put {name} (space available " + f"{current_max_size - device_memory_used[device]}, module size {module_size})." + ) + if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes: + # -> no split, we go to the next device + if verbose: + print("This module cannot be split, going to the next device.") + + else: + # -> split, we replace the module studied by its children + parameters + if verbose: + print(f"Splitting {name}.") + modules_children = list(module.named_parameters(recurse=False)) + modules_children + modules_to_treat = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_treat + # Update the max layer size. + max_layer_size, max_layer_names = get_max_layer_size( + [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)], + module_sizes, + no_split_module_classes, + ) + continue + + # If no module is assigned to the current device, we attempt to allocate a fallback module + # if fallback_allocation is enabled. + if device_memory_used[device] == 0 and fallback_allocation and device != "disk": + # We try to allocate a module that fits in the size limit using BFS. + # Recompute the current max size as we need to consider the current module as well. + current_max_size = max_memory[device] - max(max_layer_size, module_size_with_ties) + + fallback_module_name, fallback_module, remaining_modules = fallback_allocate( + modules_to_treat, + module_sizes, + current_max_size - device_memory_used[device], + no_split_module_classes, + tied_parameters, + ) + # use the next iteration to put the fallback module on the next device to avoid code duplication + if fallback_module is not None: + modules_to_treat = [(fallback_module_name, fallback_module)] + [(name, module)] + remaining_modules + continue + + if device_memory_used[device] == 0: + device_minimum_assignment_memory[device] = module_size_with_ties + current_memory_reserved + + # Neither the current module nor any tied modules can be split, so we move to the next device. + device_memory_used[device] = device_memory_used[device] + current_memory_reserved + current_device += 1 + modules_to_treat = [(name, module)] + modules_to_treat + + device_memory_used = {device: mem for device, mem in device_memory_used.items() if mem > 0} + + if clean_result: + device_map = clean_device_map(device_map) + + non_gpu_buffer_size = device_buffer_sizes.get("cpu", 0) + device_buffer_sizes.get("disk", 0) + if non_gpu_buffer_size > 0 and not offload_buffers: + is_buffer_fit_any_gpu = False + for gpu_device, gpu_max_memory in max_memory.items(): + if gpu_device == "cpu" or gpu_device == "disk": + continue + + if not is_buffer_fit_any_gpu: + gpu_memory_used = device_memory_used.get(gpu_device, 0) + + if gpu_max_memory >= non_gpu_buffer_size + gpu_memory_used: + is_buffer_fit_any_gpu = True + + if len(gpus) > 0 and not is_buffer_fit_any_gpu: + warnings.warn( + f"Current model requires {non_gpu_buffer_size} bytes of buffer for offloaded layers, which seems does " + f"not fit any GPU's remaining memory. If you are experiencing a OOM later, please consider using " + f"offload_buffers=True." + ) + + if device_minimum_assignment_memory: + devices_info = "\n".join( + f" - {device}: {mem} bytes required" for device, mem in device_minimum_assignment_memory.items() + ) + logger.info( + f"Based on the current allocation process, no modules could be assigned to the following devices due to " + f"insufficient memory:\n" + f"{devices_info}\n" + f"These minimum requirements are specific to this allocation attempt and may vary. Consider increasing " + f"the available memory for these devices to at least the specified minimum, or adjusting the model config." + ) + return device_map + + +def check_device_map(model: nn.Module, device_map: dict[str, Union[int, str, torch.device]]): + """ + Checks a device map covers everything in a given model. + + Args: + model (`torch.nn.Module`): The model to check the device map against. + device_map (`Dict[str, Union[int, str, torch.device]]`): The device map to check. + """ + all_module_names = dict(model.named_modules()) + invalid_keys = [k for k in device_map if k != "" and k not in all_module_names] + + if invalid_keys: + warnings.warn( + f"The following device_map keys do not match any submodules in the model: {invalid_keys}", UserWarning + ) + + all_model_tensors = [name for name, _ in model.state_dict().items()] + for module_name in device_map.keys(): + if module_name == "": + all_model_tensors.clear() + break + else: + all_model_tensors = [ + name + for name in all_model_tensors + if not name == module_name and not name.startswith(module_name + ".") + ] + if len(all_model_tensors) > 0: + non_covered_params = ", ".join(all_model_tensors) + raise ValueError( + f"The device_map provided does not give any device for the following parameters: {non_covered_params}" + ) + + +def load_state_dict(checkpoint_file, device_map=None): + """ + Load a checkpoint from a given file. If the checkpoint is in the safetensors format and a device map is passed, the + weights can be fast-loaded directly on the GPU. + + Args: + checkpoint_file (`str`): The path to the checkpoint to load. + device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer + name, once a given module name is inside, every submodule of it will be sent to the same device. + """ + if checkpoint_file.endswith(".safetensors"): + with safe_open(checkpoint_file, framework="pt") as f: + metadata = f.metadata() + weight_names = f.keys() + + if metadata is None: + logger.warning( + f"The safetensors archive passed at {checkpoint_file} does not contain metadata. " + "Make sure to save your model with the `save_pretrained` method. Defaulting to 'pt' metadata." + ) + metadata = {"format": "pt"} + + if metadata.get("format") not in ["pt", "tf", "flax"]: + raise OSError( + f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure " + "you save your model with the `save_pretrained` method." + ) + elif metadata["format"] != "pt": + raise ValueError(f"The checkpoint passed was saved with {metadata['format']}, we need a the pt format.") + if device_map is None: + return safe_load_file(checkpoint_file) + else: + # if we only have one device we can load everything directly + if len(set(device_map.values())) == 1: + device = list(device_map.values())[0] + target_device = device + if isinstance(device, int): + if is_npu_available(): + target_device = f"npu:{device}" + elif is_hpu_available(): + target_device = "hpu" + + return safe_load_file(checkpoint_file, device=target_device) + + devices = list(set(device_map.values()) - {"disk"}) + # cpu device should always exist as fallback option + if "cpu" not in devices: + devices.append("cpu") + + # For each device, get the weights that go there + device_weights = {device: [] for device in devices} + for module_name, device in device_map.items(): + if device in devices: + device_weights[device].extend( + [k for k in weight_names if k == module_name or k.startswith(module_name + ".")] + ) + + # all weights that haven't defined a device should be loaded on CPU + device_weights["cpu"].extend([k for k in weight_names if k not in sum(device_weights.values(), [])]) + tensors = {} + if is_tqdm_available(): + progress_bar = tqdm( + main_process_only=False, + total=sum([len(device_weights[device]) for device in devices]), + unit="w", + smoothing=0, + leave=False, + ) + else: + progress_bar = None + for device in devices: + target_device = device + if isinstance(device, int): + if is_npu_available(): + target_device = f"npu:{device}" + elif is_hpu_available(): + target_device = "hpu" + + with safe_open(checkpoint_file, framework="pt", device=target_device) as f: + for key in device_weights[device]: + if progress_bar is not None: + progress_bar.set_postfix(dev=device, refresh=False) + progress_bar.set_description(key) + tensors[key] = f.get_tensor(key) + if progress_bar is not None: + progress_bar.update() + if progress_bar is not None: + progress_bar.close() + + return tensors + else: + return torch.load(checkpoint_file, map_location=torch.device("cpu"), weights_only=True) + + +def get_state_dict_offloaded_model(model: nn.Module): + """ + Returns the state dictionary for an offloaded model via iterative onloading + + Args: + model (`torch.nn.Module`): + The offloaded model we want to save + """ + + state_dict = {} + placeholders = set() + for name, module in model.named_modules(): + if name == "": + continue + + try: + with align_module_device(module, "cpu"): + module_state_dict = module.state_dict() + except MemoryError: + raise MemoryError("Offloaded module must fit in CPU memory to call save_model!") from None + + for key in module_state_dict: + # ignore placeholder parameters that are still on the meta device + if module_state_dict[key].device == torch.device("meta"): + placeholders.add(name + f".{key}") + continue + params = module_state_dict[key] + state_dict[name + f".{key}"] = params.to("cpu") # move buffers to cpu + for key in placeholders.copy(): + if key in state_dict: + placeholders.remove(key) + if placeholders: + logger.warning(f"The following tensors were not saved because they were still on meta device: {placeholders}") + + return state_dict + + +def get_state_dict_from_offload( + module: nn.Module, + module_name: str, + state_dict: dict[str, Union[str, torch.tensor]], + device_to_put_offload: Union[int, str, torch.device] = "cpu", +): + """ + Retrieve the state dictionary (with parameters) from an offloaded module and load into a specified device (defaults + to cpu). + + Args: + module: (`torch.nn.Module`): + The module we want to retrieve a state dictionary from + module_name: (`str`): + The name of the module of interest + state_dict (`Dict[str, Union[int, str, torch.device]]`): + Dictionary of {module names: parameters} + device_to_put_offload (`Union[int, str, torch.device]`): + Device to load offloaded parameters into, defaults to the cpu. + """ + + root = module_name[: module_name.rfind(".")] # module name without .weight or .bias + + # do not move parameters if the module is not offloaded + if not has_offloaded_params(module): + device_to_put_offload = None + + # assign the device to which the offloaded parameters will be sent + with align_module_device(module, device_to_put_offload): + for m_key, params in module.state_dict().items(): + if (root + f".{m_key}") in state_dict: + state_dict[root + f".{m_key}"] = params + + return state_dict + + +def load_checkpoint_in_model( + model: nn.Module, + checkpoint: Union[str, os.PathLike], + device_map: Optional[dict[str, Union[int, str, torch.device]]] = None, + offload_folder: Optional[Union[str, os.PathLike]] = None, + dtype: Optional[Union[str, torch.dtype]] = None, + offload_state_dict: bool = False, + offload_buffers: bool = False, + keep_in_fp32_modules: Optional[list[str]] = None, + offload_8bit_bnb: bool = False, + strict: bool = False, + full_state_dict: bool = True, + broadcast_from_rank0: bool = False, +): + """ + Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are + loaded. + + + + Once loaded across devices, you still need to call [`dispatch_model`] on your model to make it able to run. To + group the checkpoint loading and dispatch in one single call, use [`load_checkpoint_and_dispatch`]. + + + + Args: + model (`torch.nn.Module`): + The model in which we want to load a checkpoint. + checkpoint (`str` or `os.PathLike`): + The folder checkpoint to load. It can be: + - a path to a file containing a whole model state dict + - a path to a `.json` file containing the index to a sharded checkpoint + - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint. + - a path to a folder containing a unique pytorch_model.bin or a model.safetensors file. + device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer + name, once a given module name is inside, every submodule of it will be sent to the same device. + offload_folder (`str` or `os.PathLike`, *optional*): + If the `device_map` contains any value `"disk"`, the folder where we will offload weights. + dtype (`str` or `torch.dtype`, *optional*): + If provided, the weights will be converted to that type when loaded. + offload_state_dict (`bool`, *optional*, defaults to `False`): + If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if + the weight of the CPU state dict + the biggest shard does not fit. + offload_buffers (`bool`, *optional*, defaults to `False`): + Whether or not to include the buffers in the weights offloaded to disk. + keep_in_fp32_modules(`List[str]`, *optional*): + A list of the modules that we keep in `torch.float32` dtype. + offload_8bit_bnb (`bool`, *optional*): + Whether or not to enable offload of 8-bit modules on cpu/disk. + strict (`bool`, *optional*, defaults to `False`): + Whether to strictly enforce that the keys in the checkpoint state_dict match the keys of the model's + state_dict. + full_state_dict (`bool`, *optional*, defaults to `True`): if this is set to `True`, all the tensors in the + loaded state_dict will be gathered. No ShardedTensor and DTensor will be in the loaded state_dict. + broadcast_from_rank0 (`False`, *optional*, defaults to `False`): when the option is `True`, a distributed + `ProcessGroup` must be initialized. rank0 should receive a full state_dict and will broadcast the tensors + in the state_dict one by one to other ranks. Other ranks will receive the tensors and shard (if applicable) + according to the local shards in the model. + + """ + if offload_8bit_bnb: + from .bnb import quantize_and_offload_8bit + + tied_params = find_tied_parameters(model) + + if check_tied_parameters_in_config(model) and len(tied_params) == 0: + logger.warning( + "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function." + ) + if device_map is not None: + check_tied_parameters_on_same_device(tied_params, device_map) + + if offload_folder is None and device_map is not None and "disk" in device_map.values(): + raise ValueError( + "At least one of the model submodule will be offloaded to disk, please pass along an `offload_folder`." + ) + elif offload_folder is not None and device_map is not None and "disk" in device_map.values(): + os.makedirs(offload_folder, exist_ok=True) + + if isinstance(dtype, str): + # We accept "torch.float16" or just "float16" + dtype = dtype.replace("torch.", "") + dtype = getattr(torch, dtype) + + checkpoint_files = None + index_filename = None + if os.path.isfile(checkpoint): + if str(checkpoint).endswith(".json"): + index_filename = checkpoint + else: + checkpoint_files = [checkpoint] + elif os.path.isdir(checkpoint): + # check if the whole state dict is present + potential_state_bin = [f for f in os.listdir(checkpoint) if f == WEIGHTS_NAME] + potential_state_safetensor = [f for f in os.listdir(checkpoint) if f == SAFE_WEIGHTS_NAME] + if len(potential_state_bin) == 1: + checkpoint_files = [os.path.join(checkpoint, potential_state_bin[0])] + elif len(potential_state_safetensor) == 1: + checkpoint_files = [os.path.join(checkpoint, potential_state_safetensor[0])] + else: + # otherwise check for sharded checkpoints + potential_index = [f for f in os.listdir(checkpoint) if f.endswith(".index.json")] + if len(potential_index) == 0: + raise ValueError( + f"{checkpoint} is not a folder containing a `.index.json` file or a {WEIGHTS_NAME} or a {SAFE_WEIGHTS_NAME} file" + ) + elif len(potential_index) == 1: + index_filename = os.path.join(checkpoint, potential_index[0]) + else: + raise ValueError( + f"{checkpoint} containing more than one `.index.json` file, delete the irrelevant ones." + ) + else: + raise ValueError( + "`checkpoint` should be the path to a file containing a whole state dict, or the index of a sharded " + f"checkpoint, or a folder containing a sharded checkpoint or the whole state dict, but got {checkpoint}." + ) + + if index_filename is not None: + checkpoint_folder = os.path.split(index_filename)[0] + with open(index_filename) as f: + index = json.loads(f.read()) + + if "weight_map" in index: + index = index["weight_map"] + checkpoint_files = sorted(list(set(index.values()))) + checkpoint_files = [os.path.join(checkpoint_folder, f) for f in checkpoint_files] + + # Logic for missing/unexepected keys goes here. + + offload_index = {} + if offload_state_dict: + state_dict_folder = tempfile.mkdtemp() + state_dict_index = {} + + unexpected_keys = set() + model_keys = set(model.state_dict().keys()) + buffer_names = [name for name, _ in model.named_buffers()] + model_devices = {t.device for t in model.state_dict().values() if isinstance(t, torch.Tensor)} + model_physical_devices = model_devices - {torch.device("meta")} + for checkpoint_file in checkpoint_files: + if device_map is None: + # exception for multi-device loading was made for the meta device in torch v2.7.0 + # https://github.com/pytorch/pytorch/blob/v2.6.0/torch/distributed/checkpoint/state_dict.py#L557-L563 + # https://github.com/pytorch/pytorch/blob/v2.7.0-rc2/torch/distributed/checkpoint/state_dict.py#L575-L587 + if is_torch_version(">=", "2.2.0") and ( + (is_torch_version(">=", "2.7.0") and len(model_physical_devices) <= 1) or len(model_devices) <= 1 + ): + from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict + + broadcast_from_rank0 &= is_torch_version(">=", "2.4.0") + loaded_checkpoint = ( + load_state_dict(checkpoint_file, device_map=device_map) + if not broadcast_from_rank0 or dist.get_rank() == 0 + else {} + ) + set_model_state_dict( + model, + loaded_checkpoint, + options=StateDictOptions( + full_state_dict=full_state_dict, + strict=strict, + **({"broadcast_from_rank0": broadcast_from_rank0} if is_torch_version(">=", "2.4.0") else {}), + ), + ) + else: + loaded_checkpoint = load_state_dict(checkpoint_file, device_map=device_map) + model.load_state_dict(loaded_checkpoint, strict=strict) + + unexpected_keys.update(set(loaded_checkpoint.keys()) - model_keys) + else: + loaded_checkpoint = load_state_dict(checkpoint_file, device_map=device_map) + + for param_name, param in loaded_checkpoint.items(): + # skip SCB parameter (for 8-bit serialization) + if "SCB" in param_name: + continue + + if param_name not in model_keys: + unexpected_keys.add(param_name) + if not strict: + continue # Skip loading this parameter. + + module_name = param_name + + while len(module_name) > 0 and module_name not in device_map: + module_name = ".".join(module_name.split(".")[:-1]) + if module_name == "" and "" not in device_map: + # TODO: group all errors and raise at the end. + raise ValueError(f"{param_name} doesn't have any device set.") + param_device = device_map[module_name] + new_dtype = dtype + if dtype is not None and torch.is_floating_point(param): + if keep_in_fp32_modules is not None and dtype == torch.float16: + proceed = False + for key in keep_in_fp32_modules: + if ((key in param_name) and (key + "." in param_name)) or key == param_name: + proceed = True + break + if proceed: + new_dtype = torch.float32 + + if "weight" in param_name and param_name.replace("weight", "SCB") in loaded_checkpoint.keys(): + if param.dtype == torch.int8: + fp16_statistics = loaded_checkpoint[param_name.replace("weight", "SCB")] + else: + fp16_statistics = None + + if param_device == "disk": + if offload_buffers or param_name not in buffer_names: + if new_dtype is None: + new_dtype = param.dtype + if offload_8bit_bnb: + quantize_and_offload_8bit( + model, param, param_name, new_dtype, offload_folder, offload_index, fp16_statistics + ) + continue + else: + set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype) + offload_weight(param, param_name, offload_folder, index=offload_index) + elif param_device == "cpu" and offload_state_dict: + if new_dtype is None: + new_dtype = param.dtype + if offload_8bit_bnb: + quantize_and_offload_8bit( + model, param, param_name, new_dtype, state_dict_folder, state_dict_index, fp16_statistics + ) + else: + set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype) + offload_weight(param, param_name, state_dict_folder, index=state_dict_index) + else: + set_module_tensor_to_device( + model, + param_name, + param_device, + value=param, + dtype=new_dtype, + fp16_statistics=fp16_statistics, + ) + + # Force Python to clean up. + del loaded_checkpoint + gc.collect() + + if not strict and len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {checkpoint} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}. This may or may not be an issue - make sure that the checkpoint does not have unnecessary parameters, or that the model definition correctly corresponds to the checkpoint." + ) + + save_offload_index(offload_index, offload_folder) + + # Load back offloaded state dict on CPU + if offload_state_dict: + load_offloaded_weights(model, state_dict_index, state_dict_folder) + shutil.rmtree(state_dict_folder) + + retie_parameters(model, tied_params) + + +def get_mixed_precision_context_manager(native_amp: bool = False, autocast_kwargs: AutocastKwargs = None): + """ + Return a context manager for autocasting mixed precision + + Args: + native_amp (`bool`, *optional*, defaults to False): + Whether mixed precision is actually enabled. + cache_enabled (`bool`, *optional*, defaults to True): + Whether the weight cache inside autocast should be enabled. + """ + state = AcceleratorState() + if autocast_kwargs is None: + autocast_kwargs = {} + else: + autocast_kwargs = autocast_kwargs.to_kwargs() + if native_amp: + device_type = ( + "cuda" + if (state.distributed_type == DistributedType.XLA and is_torch_xla_available(check_is_gpu=True)) + else state.device.type + ) + if state.mixed_precision == "fp16": + return torch.autocast(device_type=device_type, dtype=torch.float16, **autocast_kwargs) + elif state.mixed_precision in ["bf16", "fp8"] and state.distributed_type in [ + DistributedType.NO, + DistributedType.MULTI_CPU, + DistributedType.MULTI_GPU, + DistributedType.MULTI_MLU, + DistributedType.MULTI_SDAA, + DistributedType.MULTI_MUSA, + DistributedType.MULTI_NPU, + DistributedType.MULTI_XPU, + DistributedType.MULTI_HPU, + DistributedType.MULTI_NEURON, + DistributedType.FSDP, + DistributedType.XLA, + ]: + return torch.autocast(device_type=device_type, dtype=torch.bfloat16, **autocast_kwargs) + else: + return torch.autocast(device_type=device_type, **autocast_kwargs) + else: + return contextlib.nullcontext() + + +def get_grad_scaler(distributed_type: DistributedType = None, **kwargs): + """ + A generic helper which will initialize the correct `GradScaler` implementation based on the environment and return + it. + + Args: + distributed_type (`DistributedType`, *optional*, defaults to None): + The type of distributed environment. + kwargs: + Additional arguments for the utilized `GradScaler` constructor. + """ + if distributed_type == DistributedType.FSDP: + from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler + + return ShardedGradScaler(**kwargs) + if is_torch_xla_available(check_is_gpu=True): + import torch_xla.amp as xamp + + return xamp.GradScaler(**kwargs) + elif is_mlu_available(): + return torch.mlu.amp.GradScaler(**kwargs) + elif is_sdaa_available(): + return torch.sdaa.amp.GradScaler(**kwargs) + elif is_musa_available(): + return torch.musa.amp.GradScaler(**kwargs) + elif is_npu_available(): + return torch.npu.amp.GradScaler(**kwargs) + elif is_hpu_available(): + return torch.amp.GradScaler("hpu", **kwargs) + elif is_xpu_available(): + return torch.amp.GradScaler("xpu", **kwargs) + elif is_mps_available(): + if not is_torch_version(">=", "2.8.0"): + raise ValueError("Grad Scaler with MPS device requires a Pytorch >= 2.8.0") + return torch.amp.GradScaler("mps", **kwargs) + else: + if is_torch_version(">=", "2.3"): + return torch.amp.GradScaler("cuda", **kwargs) + else: + return torch.cuda.amp.GradScaler(**kwargs) + + +def has_offloaded_params(module: torch.nn.Module) -> bool: + """ + Checks if a module has offloaded parameters by checking if the given module has a AlignDevicesHook attached with + offloading enabled + + Args: + module (`torch.nn.Module`): The module to check for an offload hook. + + Returns: + bool: `True` if the module has an offload hook and offloading is enabled, `False` otherwise. + """ + from ..hooks import AlignDevicesHook # avoid circular import + + return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, AlignDevicesHook) and module._hf_hook.offload + + +@contextlib.contextmanager +def align_module_device(module: torch.nn.Module, execution_device: Optional[torch.device] = None): + """ + Context manager that moves a module's parameters to the specified execution device. + + Args: + module (`torch.nn.Module`): + Module with parameters to align. + execution_device (`torch.device`, *optional*): + If provided, overrides the module's execution device within the context. Otherwise, use hook execution + device or pass + """ + if has_offloaded_params(module): + if execution_device is not None: + original_device = module._hf_hook.execution_device + module._hf_hook.execution_device = execution_device + + try: + module._hf_hook.pre_forward(module) + yield + finally: + module._hf_hook.post_forward(module, None) + if execution_device is not None: + module._hf_hook.execution_device = original_device + + elif execution_device is not None: + devices = {name: param.device for name, param in module.named_parameters(recurse=False)} + try: + for name in devices: + set_module_tensor_to_device(module, name, execution_device) + yield + finally: + for name, device in devices.items(): + set_module_tensor_to_device(module, name, device) + + else: + yield diff --git a/accelerate/utils/offload.py b/accelerate/utils/offload.py new file mode 100644 index 0000000000000000000000000000000000000000..da8cfaf3ebf7f2965f408e2b952103ac26868647 --- /dev/null +++ b/accelerate/utils/offload.py @@ -0,0 +1,213 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from collections.abc import Mapping +from typing import Optional, Union + +import numpy as np +import torch +from safetensors import safe_open + + +def offload_weight(weight, weight_name, offload_folder, index=None): + dtype = None + # Check the string instead of the dtype to be compatible with versions of PyTorch that don't have bfloat16. + if str(weight.dtype) == "torch.bfloat16": + # Need to reinterpret the underlined data as int16 since NumPy does not handle bfloat16s. + weight = weight.view(torch.int16) + dtype = "bfloat16" + array = weight.cpu().numpy() + tensor_file = os.path.join(offload_folder, f"{weight_name}.dat") + if index is not None: + if dtype is None: + dtype = str(array.dtype) + index[weight_name] = {"dtype": dtype, "shape": list(array.shape)} + if array.ndim == 0: + array = array[None] + file_array = np.memmap(tensor_file, dtype=array.dtype, mode="w+", shape=array.shape) + file_array[:] = array[:] + file_array.flush() + return index + + +def load_offloaded_weight(weight_file, weight_info): + shape = tuple(weight_info["shape"]) + if shape == (): + # NumPy memory-mapped arrays can't have 0 dims so it was saved as 1d tensor + shape = (1,) + + dtype = weight_info["dtype"] + if dtype == "bfloat16": + # NumPy does not support bfloat16 so this was saved as a int16 + dtype = "int16" + + weight = np.memmap(weight_file, dtype=dtype, shape=shape, mode="r") + + if len(weight_info["shape"]) == 0: + weight = weight[0] + weight = torch.tensor(weight) + if weight_info["dtype"] == "bfloat16": + weight = weight.view(torch.bfloat16) + + return weight + + +def save_offload_index(index, offload_folder): + if index is None or len(index) == 0: + # Nothing to save + return + + offload_index_file = os.path.join(offload_folder, "index.json") + if os.path.isfile(offload_index_file): + with open(offload_index_file, encoding="utf-8") as f: + current_index = json.load(f) + else: + current_index = {} + current_index.update(index) + + with open(offload_index_file, "w", encoding="utf-8") as f: + json.dump(current_index, f, indent=2) + + +def offload_state_dict(save_dir: Union[str, os.PathLike], state_dict: dict[str, torch.Tensor]): + """ + Offload a state dict in a given folder. + + Args: + save_dir (`str` or `os.PathLike`): + The directory in which to offload the state dict. + state_dict (`Dict[str, torch.Tensor]`): + The dictionary of tensors to offload. + """ + os.makedirs(save_dir, exist_ok=True) + index = {} + for name, parameter in state_dict.items(): + index = offload_weight(parameter, name, save_dir, index=index) + + # Update index + save_offload_index(index, save_dir) + + +class PrefixedDataset(Mapping): + """ + Will access keys in a given dataset by adding a prefix. + + Args: + dataset (`Mapping`): Any map with string keys. + prefix (`str`): A prefix to add when trying to access any element in the underlying dataset. + """ + + def __init__(self, dataset: Mapping, prefix: str): + self.dataset = dataset + self.prefix = prefix + + def __getitem__(self, key): + return self.dataset[f"{self.prefix}{key}"] + + def __iter__(self): + return iter([key for key in self.dataset if key.startswith(self.prefix)]) + + def __len__(self): + return len(self.dataset) + + +class OffloadedWeightsLoader(Mapping): + """ + A collection that loads weights stored in a given state dict or memory-mapped on disk. + + Args: + state_dict (`Dict[str, torch.Tensor]`, *optional*): + A dictionary parameter name to tensor. + save_folder (`str` or `os.PathLike`, *optional*): + The directory in which the weights are stored (by `offload_state_dict` for instance). + index (`Dict`, *optional*): + A dictionary from weight name to their information (`dtype`/ `shape` or safetensors filename). Will default + to the index saved in `save_folder`. + """ + + def __init__( + self, + state_dict: Optional[dict[str, torch.Tensor]] = None, + save_folder: Optional[Union[str, os.PathLike]] = None, + index: Optional[Mapping] = None, + device=None, + ): + if state_dict is None and save_folder is None and index is None: + raise ValueError("Need either a `state_dict`, a `save_folder` or an `index` containing offloaded weights.") + + self.state_dict = {} if state_dict is None else state_dict + self.save_folder = save_folder + if index is None and save_folder is not None: + with open(os.path.join(save_folder, "index.json")) as f: + index = json.load(f) + self.index = {} if index is None else index + self.all_keys = list(self.state_dict.keys()) + self.all_keys.extend([key for key in self.index if key not in self.all_keys]) + self.device = device + + def __getitem__(self, key: str): + # State dict gets priority + if key in self.state_dict: + return self.state_dict[key] + weight_info = self.index[key] + if weight_info.get("safetensors_file") is not None: + device = "cpu" if self.device is None else self.device + tensor = None + try: + with safe_open(weight_info["safetensors_file"], framework="pt", device=device) as f: + tensor = f.get_tensor(weight_info.get("weight_name", key)) + except TypeError: + # if failed to get_tensor on the device, such as bf16 on mps, try to load it on CPU first + with safe_open(weight_info["safetensors_file"], framework="pt", device="cpu") as f: + tensor = f.get_tensor(weight_info.get("weight_name", key)) + + if "dtype" in weight_info: + tensor = tensor.to(getattr(torch, weight_info["dtype"])) + + if tensor.device != torch.device(device): + tensor = tensor.to(device) + return tensor + + weight_file = os.path.join(self.save_folder, f"{key}.dat") + return load_offloaded_weight(weight_file, weight_info) + + def __iter__(self): + return iter(self.all_keys) + + def __len__(self): + return len(self.all_keys) + + +def extract_submodules_state_dict(state_dict: dict[str, torch.Tensor], submodule_names: list[str]): + """ + Extract the sub state-dict corresponding to a list of given submodules. + + Args: + state_dict (`Dict[str, torch.Tensor]`): The state dict to extract from. + submodule_names (`List[str]`): The list of submodule names we want to extract. + """ + result = {} + for module_name in submodule_names: + # We want to catch module_name parameter (module_name.xxx) or potentially module_name, but not any of the + # submodules that could being like module_name (transformers.h.1 and transformers.h.10 for instance) + result.update( + { + key: param + for key, param in state_dict.items() + if key == module_name or key.startswith(module_name + ".") + } + ) + return result diff --git a/accelerate/utils/operations.py b/accelerate/utils/operations.py new file mode 100644 index 0000000000000000000000000000000000000000..ceec9b457fe35cf146ffe3c65c73a4b2a8fb202a --- /dev/null +++ b/accelerate/utils/operations.py @@ -0,0 +1,871 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A set of basic tensor ops compatible with tpu, gpu, and multigpu +""" + +import pickle +import warnings +from collections.abc import Mapping +from contextlib import contextmanager, nullcontext +from functools import update_wrapper, wraps +from typing import Any + +import torch + +from ..state import AcceleratorState, PartialState +from .constants import TORCH_DISTRIBUTED_OPERATION_TYPES +from .dataclasses import DistributedType, TensorInformation +from .imports import ( + is_npu_available, + is_torch_distributed_available, + is_torch_xla_available, +) +from .versions import is_torch_version + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + +if is_torch_distributed_available(): + from torch.distributed import ReduceOp + + +def is_torch_tensor(tensor): + return isinstance(tensor, torch.Tensor) + + +def is_torch_xpu_tensor(tensor): + return isinstance( + tensor, + torch.xpu.FloatTensor, + torch.xpu.ByteTensor, + torch.xpu.IntTensor, + torch.xpu.LongTensor, + torch.xpu.HalfTensor, + torch.xpu.DoubleTensor, + torch.xpu.BFloat16Tensor, + ) + + +def is_tensor_information(tensor_info): + return isinstance(tensor_info, TensorInformation) + + +def is_namedtuple(data): + """ + Checks if `data` is a `namedtuple` or not. Can have false positives, but only if a user is trying to mimic a + `namedtuple` perfectly. + """ + return isinstance(data, tuple) and hasattr(data, "_asdict") and hasattr(data, "_fields") + + +def honor_type(obj, generator): + """ + Cast a generator to the same type as obj (list, tuple, or namedtuple) + """ + # Some objects may not be able to instantiate from a generator directly + if is_namedtuple(obj): + return type(obj)(*list(generator)) + else: + return type(obj)(generator) + + +def recursively_apply(func, data, *args, test_type=is_torch_tensor, error_on_other_type=False, **kwargs): + """ + Recursively apply a function on a data structure that is a nested list/tuple/dictionary of a given base type. + + Args: + func (`callable`): + The function to recursively apply. + data (nested list/tuple/dictionary of `main_type`): + The data on which to apply `func` + *args: + Positional arguments that will be passed to `func` when applied on the unpacked data. + main_type (`type`, *optional*, defaults to `torch.Tensor`): + The base type of the objects to which apply `func`. + error_on_other_type (`bool`, *optional*, defaults to `False`): + Whether to return an error or not if after unpacking `data`, we get on an object that is not of type + `main_type`. If `False`, the function will leave objects of types different than `main_type` unchanged. + **kwargs (additional keyword arguments, *optional*): + Keyword arguments that will be passed to `func` when applied on the unpacked data. + + Returns: + The same data structure as `data` with `func` applied to every object of type `main_type`. + """ + if isinstance(data, (tuple, list)): + return honor_type( + data, + ( + recursively_apply( + func, o, *args, test_type=test_type, error_on_other_type=error_on_other_type, **kwargs + ) + for o in data + ), + ) + elif isinstance(data, Mapping): + return type(data)( + { + k: recursively_apply( + func, v, *args, test_type=test_type, error_on_other_type=error_on_other_type, **kwargs + ) + for k, v in data.items() + } + ) + elif test_type(data): + return func(data, *args, **kwargs) + elif error_on_other_type: + raise TypeError( + f"Unsupported types ({type(data)}) passed to `{func.__name__}`. Only nested list/tuple/dicts of " + f"objects that are valid for `{test_type.__name__}` should be passed." + ) + return data + + +def send_to_device(tensor, device, non_blocking=False, skip_keys=None): + """ + Recursively sends the elements in a nested list/tuple/dictionary of tensors to a given device. + + Args: + tensor (nested list/tuple/dictionary of `torch.Tensor`): + The data to send to a given device. + device (`torch.device`): + The device to send the data to. + + Returns: + The same data structure as `tensor` with all tensors sent to the proper device. + """ + if is_torch_tensor(tensor) or hasattr(tensor, "to"): + # `torch.Tensor.to("npu")` could not find context when called for the first time (see this [issue](https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue)). + if device == "npu": + device = "npu:0" + try: + return tensor.to(device, non_blocking=non_blocking) + except TypeError: # .to() doesn't accept non_blocking as kwarg + return tensor.to(device) + except AssertionError as error: + # `torch.Tensor.to()` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)). + # This call is inside the try-block since is_npu_available is not supported by torch.compile. + if is_npu_available(): + if isinstance(device, int): + device = f"npu:{device}" + else: + raise error + try: + return tensor.to(device, non_blocking=non_blocking) + except TypeError: # .to() doesn't accept non_blocking as kwarg + return tensor.to(device) + elif isinstance(tensor, (tuple, list)): + return honor_type( + tensor, (send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys) for t in tensor) + ) + elif isinstance(tensor, Mapping): + if isinstance(skip_keys, str): + skip_keys = [skip_keys] + elif skip_keys is None: + skip_keys = [] + return type(tensor)( + { + k: t if k in skip_keys else send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys) + for k, t in tensor.items() + } + ) + else: + return tensor + + +def get_data_structure(data): + """ + Recursively gathers the information needed to rebuild a nested list/tuple/dictionary of tensors. + + Args: + data (nested list/tuple/dictionary of `torch.Tensor`): + The data to send to analyze. + + Returns: + The same data structure as `data` with [`~utils.TensorInformation`] instead of tensors. + """ + + def _get_data_structure(tensor): + return TensorInformation(shape=tensor.shape, dtype=tensor.dtype) + + return recursively_apply(_get_data_structure, data) + + +def get_shape(data): + """ + Recursively gathers the shape of a nested list/tuple/dictionary of tensors as a list. + + Args: + data (nested list/tuple/dictionary of `torch.Tensor`): + The data to send to analyze. + + Returns: + The same data structure as `data` with lists of tensor shapes instead of tensors. + """ + + def _get_shape(tensor): + return list(tensor.shape) + + return recursively_apply(_get_shape, data) + + +def initialize_tensors(data_structure): + """ + Recursively initializes tensors from a nested list/tuple/dictionary of [`~utils.TensorInformation`]. + + Returns: + The same data structure as `data` with tensors instead of [`~utils.TensorInformation`]. + """ + + def _initialize_tensor(tensor_info): + return torch.empty(*tensor_info.shape, dtype=tensor_info.dtype) + + return recursively_apply(_initialize_tensor, data_structure, test_type=is_tensor_information) + + +def find_batch_size(data): + """ + Recursively finds the batch size in a nested list/tuple/dictionary of lists of tensors. + + Args: + data (nested list/tuple/dictionary of `torch.Tensor`): The data from which to find the batch size. + + Returns: + `int`: The batch size. + """ + if isinstance(data, (tuple, list, Mapping)) and (len(data) == 0): + raise ValueError(f"Cannot find the batch size from empty {type(data)}.") + + if isinstance(data, (tuple, list)): + return find_batch_size(data[0]) + elif isinstance(data, Mapping): + for k in data.keys(): + return find_batch_size(data[k]) + elif not isinstance(data, torch.Tensor): + raise TypeError(f"Can only find the batch size of tensors but got {type(data)}.") + return data.shape[0] + + +def ignorant_find_batch_size(data): + """ + Same as [`utils.operations.find_batch_size`] except will ignore if `ValueError` and `TypeErrors` are raised + + Args: + data (nested list/tuple/dictionary of `torch.Tensor`): The data from which to find the batch size. + + Returns: + `int`: The batch size. + """ + try: + return find_batch_size(data) + except (ValueError, TypeError): + pass + return None + + +def listify(data): + """ + Recursively finds tensors in a nested list/tuple/dictionary and converts them to a list of numbers. + + Args: + data (nested list/tuple/dictionary of `torch.Tensor`): The data from which to convert to regular numbers. + + Returns: + The same data structure as `data` with lists of numbers instead of `torch.Tensor`. + """ + + def _convert_to_list(tensor): + tensor = tensor.detach().cpu() + if tensor.dtype == torch.bfloat16: + # As of Numpy 1.21.4, NumPy does not support bfloat16 (see + # https://github.com/numpy/numpy/blob/a47ecdea856986cd60eabbd53265c2ca5916ad5d/doc/source/user/basics.types.rst ). + # Until Numpy adds bfloat16, we must convert float32. + tensor = tensor.to(torch.float32) + return tensor.tolist() + + return recursively_apply(_convert_to_list, data) + + +def _tpu_gather(tensor): + def _tpu_gather_one(tensor): + if tensor.ndim == 0: + tensor = tensor.clone()[None] + + # Can only gather contiguous tensors + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + return xm.all_gather(tensor) + + res = recursively_apply(_tpu_gather_one, tensor, error_on_other_type=True) + xm.mark_step() + return res + + +def _gpu_gather(tensor): + state = PartialState() + gather_op = torch.distributed.all_gather_into_tensor + + # NOTE: need manually synchronize to workaourd a INT64 collectives bug in oneCCL before torch 2.9.0 + if state.device.type == "xpu" and is_torch_version("<=", "2.8"): + torch.xpu.synchronize() + + def _gpu_gather_one(tensor): + if tensor.ndim == 0: + tensor = tensor.clone()[None] + + # Can only gather contiguous tensors + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + + if state.backend is not None and state.backend != "gloo": + # We use `empty` as `all_gather_into_tensor` slightly + # differs from `all_gather` for better efficiency, + # and we rely on the number of items in the tensor + # rather than its direct shape + output_tensors = torch.empty( + state.num_processes * tensor.numel(), + dtype=tensor.dtype, + device=state.device, + ) + gather_op(output_tensors, tensor) + return output_tensors.view(-1, *tensor.size()[1:]) + else: + # a backend of `None` is always CPU + # also gloo does not support `all_gather_into_tensor`, + # which will result in a larger memory overhead for the op + output_tensors = [torch.empty_like(tensor) for _ in range(state.num_processes)] + torch.distributed.all_gather(output_tensors, tensor) + return torch.cat(output_tensors, dim=0) + + return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True) + + +class DistributedOperationException(Exception): + """ + An exception class for distributed operations. Raised if the operation cannot be performed due to the shape of the + tensors. + """ + + pass + + +def verify_operation(function): + """ + Verifies that `tensor` is the same shape across all processes. Only ran if `PartialState().debug` is `True`. + """ + + @wraps(function) + def wrapper(*args, **kwargs): + if PartialState().distributed_type == DistributedType.NO or not PartialState().debug: + return function(*args, **kwargs) + operation = f"{function.__module__}.{function.__name__}" + if "tensor" in kwargs: + tensor = kwargs["tensor"] + else: + tensor = args[0] + if PartialState().device.type != find_device(tensor).type: + raise DistributedOperationException( + f"One or more of the tensors passed to {operation} were not on the {tensor.device.type} while the `Accelerator` is configured for {PartialState().device.type}. " + f"Please move it to the {PartialState().device.type} before calling {operation}." + ) + shapes = get_shape(tensor) + output = gather_object([shapes]) + if output[0] is not None: + are_same = output.count(output[0]) == len(output) + if not are_same: + process_shape_str = "\n - ".join([f"Process {i}: {shape}" for i, shape in enumerate(output)]) + raise DistributedOperationException( + f"Cannot apply desired operation due to shape mismatches. " + "All shapes across devices must be valid." + f"\n\nOperation: `{operation}`\nInput shapes:\n - {process_shape_str}" + ) + return function(*args, **kwargs) + + return wrapper + + +def chained_operation(function): + """ + Checks that `verify_operation` failed and if so reports a more helpful error chaining the existing + `DistributedOperationException`. + """ + + @wraps(function) + def wrapper(*args, **kwargs): + try: + return function(*args, **kwargs) + except DistributedOperationException as e: + operation = f"{function.__module__}.{function.__name__}" + raise DistributedOperationException( + f"Error found while calling `{operation}`. Please see the earlier error for more details." + ) from e + + return wrapper + + +@verify_operation +def gather(tensor): + """ + Recursively gather tensor in a nested list/tuple/dictionary of tensors from all devices. + + Args: + tensor (nested list/tuple/dictionary of `torch.Tensor`): + The data to gather. + + Returns: + The same data structure as `tensor` with all tensors sent to the proper device. + """ + if PartialState().distributed_type == DistributedType.XLA: + return _tpu_gather(tensor) + elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES: + return _gpu_gather(tensor) + else: + return tensor + + +def _gpu_gather_object(object: Any): + output_objects = [None for _ in range(PartialState().num_processes)] + torch.distributed.all_gather_object(output_objects, object) + # all_gather_object returns a list of lists, so we need to flatten it + return [x for y in output_objects for x in y] + + +def gather_object(object: Any): + """ + Recursively gather object in a nested list/tuple/dictionary of objects from all devices. + + Args: + object (nested list/tuple/dictionary of picklable object): + The data to gather. + + Returns: + The same data structure as `object` with all the objects sent to every device. + """ + if PartialState().distributed_type == DistributedType.XLA: + raise NotImplementedError("gather objects in TPU is not supported") + elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES: + return _gpu_gather_object(object) + else: + return object + + +def _gpu_broadcast(data, src=0): + def _gpu_broadcast_one(tensor, src=0): + torch.distributed.broadcast(tensor, src=src) + return tensor + + return recursively_apply(_gpu_broadcast_one, data, error_on_other_type=True, src=src) + + +def _tpu_broadcast(tensor, src=0, name="broadcast tensor"): + if isinstance(tensor, (list, tuple)): + return honor_type(tensor, (_tpu_broadcast(t, name=f"{name}_{i}") for i, t in enumerate(tensor))) + elif isinstance(tensor, Mapping): + return type(tensor)({k: _tpu_broadcast(v, name=f"{name}_{k}") for k, v in tensor.items()}) + return xm.mesh_reduce(name, tensor, lambda x: x[src]) + + +TENSOR_TYPE_TO_INT = { + torch.float: 1, + torch.double: 2, + torch.half: 3, + torch.bfloat16: 4, + torch.uint8: 5, + torch.int8: 6, + torch.int16: 7, + torch.int32: 8, + torch.int64: 9, + torch.bool: 10, +} + +TENSOR_INT_TO_DTYPE = {v: k for k, v in TENSOR_TYPE_TO_INT.items()} + + +def gather_tensor_shape(tensor): + """ + Grabs the shape of `tensor` only available on one process and returns a tensor of its shape + """ + # Allocate 80 bytes to store the shape + max_tensor_dimension = 2**20 + state = PartialState() + base_tensor = torch.empty(max_tensor_dimension, dtype=torch.int, device=state.device) + + # Since PyTorch can't just send a tensor to another GPU without + # knowing its size, we store the size of the tensor with data + # in an allocation + if tensor is not None: + shape = tensor.shape + tensor_dtype = TENSOR_TYPE_TO_INT[tensor.dtype] + base_tensor[: len(shape) + 1] = torch.tensor(list(shape) + [tensor_dtype], dtype=int) + # Perform a reduction to copy the size data onto all GPUs + base_tensor = reduce(base_tensor, reduction="sum") + base_tensor = base_tensor[base_tensor.nonzero()] + # The last non-zero data contains the coded dtype the source tensor is + dtype = int(base_tensor[-1:][0]) + base_tensor = base_tensor[:-1] + return base_tensor, dtype + + +def copy_tensor_to_devices(tensor=None) -> torch.Tensor: + """ + Copies a tensor that only exists on a single device and broadcasts it to other devices. Differs from `broadcast` as + each worker doesn't need to know its shape when used (and tensor can be `None`) + + Args: + tensor (`torch.tensor`): + The tensor that should be sent to all devices. Must only have it be defined on a single device, the rest + should be `None`. + """ + state = PartialState() + shape, dtype = gather_tensor_shape(tensor) + if tensor is None: + tensor = torch.zeros(shape, dtype=TENSOR_INT_TO_DTYPE[dtype]).to(state.device) + return reduce(tensor, reduction="sum") + + +@verify_operation +def broadcast(tensor, from_process: int = 0): + """ + Recursively broadcast tensor in a nested list/tuple/dictionary of tensors to all devices. + + Args: + tensor (nested list/tuple/dictionary of `torch.Tensor`): + The data to gather. + from_process (`int`, *optional*, defaults to 0): + The process from which to send the data + + Returns: + The same data structure as `tensor` with all tensors broadcasted to the proper device. + """ + if PartialState().distributed_type == DistributedType.XLA: + return _tpu_broadcast(tensor, src=from_process, name="accelerate.utils.broadcast") + elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES: + return _gpu_broadcast(tensor, src=from_process) + else: + return tensor + + +def broadcast_object_list(object_list, from_process: int = 0): + """ + Broadcast a list of picklable objects from one process to the others. + + Args: + object_list (list of picklable objects): + The list of objects to broadcast. This list will be modified inplace. + from_process (`int`, *optional*, defaults to 0): + The process from which to send the data. + + Returns: + The same list containing the objects from process 0. + """ + if PartialState().distributed_type == DistributedType.XLA: + for i, obj in enumerate(object_list): + object_list[i] = xm.mesh_reduce("accelerate.utils.broadcast_object_list", obj, lambda x: x[from_process]) + elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES: + torch.distributed.broadcast_object_list(object_list, src=from_process) + return object_list + + +def slice_tensors(data, tensor_slice, process_index=None, num_processes=None): + """ + Recursively takes a slice in a nested list/tuple/dictionary of tensors. + + Args: + data (nested list/tuple/dictionary of `torch.Tensor`): + The data to slice. + tensor_slice (`slice`): + The slice to take. + + Returns: + The same data structure as `data` with all the tensors slices. + """ + + def _slice_tensor(tensor, tensor_slice): + return tensor[tensor_slice] + + return recursively_apply(_slice_tensor, data, tensor_slice) + + +def concatenate(data, dim=0): + """ + Recursively concatenate the tensors in a nested list/tuple/dictionary of lists of tensors with the same shape. + If there is only a single batch of data, it is returned as-is. + + Args: + data (nested list/tuple/dictionary of lists of tensors `torch.Tensor`): + The data to concatenate. + dim (`int`, *optional*, defaults to 0): + The dimension on which to concatenate. + + Returns: + The same data structure as `data` with all the tensors concatenated. + """ + if isinstance(data[0], (tuple, list)): + return honor_type(data[0], (concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0])))) + elif isinstance(data[0], Mapping): + return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()}) + elif isinstance(data[0], torch.Tensor): + return torch.cat(data, dim=dim) + elif isinstance(data, (tuple, list)) and len(data) == 1: + return data[0] + else: + raise TypeError(f"Can only concatenate tensors but got {type(data[0])}") + + +class CannotPadNestedTensorWarning(UserWarning): + pass + + +@chained_operation +def pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False): + """ + Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so they + can safely be gathered. + + Args: + tensor (nested list/tuple/dictionary of `torch.Tensor`): + The data to gather. + dim (`int`, *optional*, defaults to 0): + The dimension on which to pad. + pad_index (`int`, *optional*, defaults to 0): + The value with which to pad. + pad_first (`bool`, *optional*, defaults to `False`): + Whether to pad at the beginning or the end. + """ + + def _pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False): + if getattr(tensor, "is_nested", False): + warnings.warn( + "Cannot pad nested tensors without more information. Leaving unprocessed.", + CannotPadNestedTensorWarning, + ) + return tensor + if dim >= len(tensor.shape) or dim < -len(tensor.shape): + return tensor + # Convert negative dimensions to non-negative + if dim < 0: + dim += len(tensor.shape) + + # Gather all sizes + size = torch.tensor(tensor.shape, device=tensor.device)[None] + sizes = gather(size).cpu() + # Then pad to the maximum size + max_size = max(s[dim] for s in sizes) + if max_size == tensor.shape[dim]: + return tensor + + old_size = tensor.shape + new_size = list(old_size) + new_size[dim] = max_size + new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index + if pad_first: + indices = tuple( + slice(max_size - old_size[dim], max_size) if i == dim else slice(None) for i in range(len(new_size)) + ) + else: + indices = tuple(slice(0, old_size[dim]) if i == dim else slice(None) for i in range(len(new_size))) + new_tensor[indices] = tensor + return new_tensor + + return recursively_apply( + _pad_across_processes, tensor, error_on_other_type=True, dim=dim, pad_index=pad_index, pad_first=pad_first + ) + + +def pad_input_tensors(tensor, batch_size, num_processes, dim=0): + """ + Takes a `tensor` of arbitrary size and pads it so that it can work given `num_processes` needed dimensions. + + New tensors are just the last input repeated. + + E.g.: + Tensor: ([3,4,4]) Num processes: 4 Expected result shape: ([4,4,4]) + + """ + + def _pad_input_tensors(tensor, batch_size, num_processes, dim=0): + remainder = batch_size // num_processes + last_inputs = batch_size - (remainder * num_processes) + if batch_size // num_processes == 0: + to_pad = num_processes - batch_size + else: + to_pad = num_processes - (batch_size // num_processes) + # In the rare case that `to_pad` is negative, + # we need to pad the last inputs - the found `to_pad` + if last_inputs > to_pad & to_pad < 1: + to_pad = last_inputs - to_pad + old_size = tensor.shape + new_size = list(old_size) + new_size[0] = batch_size + to_pad + new_tensor = tensor.new_zeros(tuple(new_size)) + indices = tuple(slice(0, old_size[dim]) if i == dim else slice(None) for i in range(len(new_size))) + new_tensor[indices] = tensor + return new_tensor + + return recursively_apply( + _pad_input_tensors, + tensor, + error_on_other_type=True, + batch_size=batch_size, + num_processes=num_processes, + dim=dim, + ) + + +@verify_operation +def reduce(tensor, reduction="mean", scale=1.0): + """ + Recursively reduce the tensors in a nested list/tuple/dictionary of lists of tensors across all processes by the + mean of a given operation. + + Args: + tensor (nested list/tuple/dictionary of `torch.Tensor`): + The data to reduce. + reduction (`str`, *optional*, defaults to `"mean"`): + A reduction method. Can be of "mean", "sum", or "none" + scale (`float`, *optional*): + A default scaling value to be applied after the reduce, only valid on XLA. + + Returns: + The same data structure as `data` with all the tensors reduced. + """ + + def _reduce_across_processes(tensor, reduction="mean", scale=1.0): + state = PartialState() + cloned_tensor = tensor.clone() + if state.distributed_type == DistributedType.NO: + return cloned_tensor + if state.distributed_type == DistributedType.XLA: + # Some processes may have different HLO graphs than other + # processes, for example in the breakpoint API + # accelerator.set_trigger(). Use mark_step to make HLOs + # the same on all processes. + xm.mark_step() + xm.all_reduce(xm.REDUCE_SUM, [cloned_tensor], scale) + xm.mark_step() + elif state.distributed_type.value in TORCH_DISTRIBUTED_OPERATION_TYPES: + torch.distributed.all_reduce(cloned_tensor, ReduceOp.SUM) + if reduction == "mean": + cloned_tensor /= state.num_processes + return cloned_tensor + + return recursively_apply( + _reduce_across_processes, tensor, error_on_other_type=True, reduction=reduction, scale=scale + ) + + +def convert_to_fp32(tensor): + """ + Recursively converts the elements nested list/tuple/dictionary of tensors in FP16/BF16 precision to FP32. + + Args: + tensor (nested list/tuple/dictionary of `torch.Tensor`): + The data to convert from FP16/BF16 to FP32. + + Returns: + The same data structure as `tensor` with all tensors that were in FP16/BF16 precision converted to FP32. + """ + + def _convert_to_fp32(tensor): + return tensor.float() + + def _is_fp16_bf16_tensor(tensor): + return (is_torch_tensor(tensor) or hasattr(tensor, "dtype")) and tensor.dtype in ( + torch.float16, + torch.bfloat16, + ) + + return recursively_apply(_convert_to_fp32, tensor, test_type=_is_fp16_bf16_tensor) + + +class ConvertOutputsToFp32: + """ + Decorator to apply to a function outputting tensors (like a model forward pass) that ensures the outputs in FP16 + precision will be convert back to FP32. + + Args: + model_forward (`Callable`): + The function which outputs we want to treat. + + Returns: + The same function as `model_forward` but with converted outputs. + """ + + def __init__(self, model_forward): + self.model_forward = model_forward + update_wrapper(self, model_forward) + + def __call__(self, *args, **kwargs): + return convert_to_fp32(self.model_forward(*args, **kwargs)) + + def __getstate__(self): + raise pickle.PicklingError( + "Cannot pickle a prepared model with automatic mixed precision, please unwrap the model with `Accelerator.unwrap_model(model)` before pickling it." + ) + + +def convert_outputs_to_fp32(model_forward): + model_forward = ConvertOutputsToFp32(model_forward) + + def forward(*args, **kwargs): + return model_forward(*args, **kwargs) + + # To act like a decorator so that it can be popped when doing `extract_model_from_parallel` + forward.__wrapped__ = model_forward + + return forward + + +def find_device(data): + """ + Finds the device on which a nested dict/list/tuple of tensors lies (assuming they are all on the same device). + + Args: + (nested list/tuple/dictionary of `torch.Tensor`): The data we want to know the device of. + """ + if isinstance(data, Mapping): + for obj in data.values(): + device = find_device(obj) + if device is not None: + return device + elif isinstance(data, (tuple, list)): + for obj in data: + device = find_device(obj) + if device is not None: + return device + elif isinstance(data, torch.Tensor): + return data.device + + +@contextmanager +def GatheredParameters(params, modifier_rank=None, fwd_module=None, enabled=True): + """ + Wrapper around `deepspeed.runtime.zero.GatheredParameters`, but if Zero-3 is not enabled, will be a no-op context + manager. + """ + # We need to use the `AcceleratorState` here since it has access to the deepspeed plugin + if AcceleratorState().distributed_type != DistributedType.DEEPSPEED or ( + AcceleratorState().deepspeed_plugin is not None + and not AcceleratorState().deepspeed_plugin.is_zero3_init_enabled() + ): + gather_param_context = nullcontext() + else: + import deepspeed + + gather_param_context = deepspeed.zero.GatheredParameters( + params, modifier_rank=modifier_rank, fwd_module=fwd_module, enabled=enabled + ) + with gather_param_context: + yield diff --git a/accelerate/utils/other.py b/accelerate/utils/other.py new file mode 100644 index 0000000000000000000000000000000000000000..398639a164320b9f9f6a14ab82166ddaf8d75f4f --- /dev/null +++ b/accelerate/utils/other.py @@ -0,0 +1,568 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import platform +import re +import socket +from codecs import encode +from collections import OrderedDict +from functools import partial, reduce +from types import MethodType +from typing import Optional + +import numpy as np +import torch +from packaging.version import Version +from safetensors.torch import save_file as safe_save_file + +from ..commands.config.default import write_basic_config # noqa: F401 +from ..logging import get_logger +from ..state import PartialState +from .constants import FSDP_PYTORCH_VERSION +from .dataclasses import DistributedType +from .imports import ( + is_deepspeed_available, + is_numpy_available, + is_torch_distributed_available, + is_torch_xla_available, + is_weights_only_available, +) +from .modeling import id_tensor_storage +from .transformer_engine import convert_model +from .versions import is_torch_version + + +logger = get_logger(__name__) + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + +def is_compiled_module(module: torch.nn.Module) -> bool: + """ + Check whether the module was compiled with torch.compile() + """ + if not hasattr(torch, "_dynamo"): + return False + + return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) + + +def has_compiled_regions(module: torch.nn.Module) -> bool: + """ + Check whether the module has submodules that were compiled with `torch.compile()`. + """ + if not hasattr(torch, "_dynamo"): + return False + + if module._modules: + for submodule in module.modules(): + if isinstance(submodule, torch._dynamo.eval_frame.OptimizedModule): + return True + + return False + + +def is_repeated_blocks(module: torch.nn.Module) -> bool: + """ + Check whether the module is a repeated block, i.e. `torch.nn.ModuleList` with all children of the same class. This + is useful to determine whether we should apply regional compilation to the module. + """ + + return ( + isinstance(module, torch.nn.ModuleList) + and len(module) > 0 + and all(isinstance(m, module[0].__class__) for m in module) + ) + + +def has_repeated_blocks(module: torch.nn.Module) -> bool: + """ + Check whether the module has repeated blocks, i.e. `torch.nn.ModuleList` with all children of the same class, at + any level of the module hierarchy. This is useful to determine whether we should apply regional compilation to the + module. + """ + if module._modules: + for submodule in module.modules(): + if is_repeated_blocks(submodule): + return True + + return False + + +def compile_regions(module: torch.nn.Module, **compile_kwargs) -> torch.nn.Module: + """ + Performs regional compilation where we target repeated blocks of the same class and compile them sequentially to + hit the compiler's cache. For example, in `GPT2LMHeadModel`, the repeated block/class is `GPT2Block`, and can be + accessed as `model.transformer.h[0]`. The rest of the model (e.g. model.lm_head) is compiled separately. + + This allows us to speed up the compilation overhead / cold start of models like LLMs and Transformers in general. + See https://pytorch.org/tutorials/recipes/regional_compilation.html for more details. + + Args: + module (`torch.nn.Module`): + The model to compile. + **compile_kwargs: + Additional keyword arguments to pass to `torch.compile()`. + + Returns: + `torch.nn.Module`: A new instance of the model with some compiled regions. + + Example: + ```python + >>> from accelerate.utils import compile_regions + >>> from transformers import AutoModelForCausalLM + + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + >>> compiled_model = compile_regions(model, mode="reduce-overhead") + >>> compiled_model.transformer.h[0] + OptimizedModule( + (_orig_mod): GPT2Block( + (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True) + (attn): GPT2Attention( + (c_attn): Conv1D(nf=2304, nx=768) + (c_proj): Conv1D(nf=768, nx=768) + (attn_dropout): Dropout(p=0.1, inplace=False) + (resid_dropout): Dropout(p=0.1, inplace=False) + ) + (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True) + (mlp): GPT2MLP( + (c_fc): Conv1D(nf=3072, nx=768) + (c_proj): Conv1D(nf=768, nx=3072) + (act): NewGELUActivation() + (dropout): Dropout(p=0.1, inplace=False) + ) + ) + ) + ``` + """ + + def _compile_regions(module: torch.nn.Module, **compile_kwargs) -> torch.nn.Module: + if is_repeated_blocks(module): + new_module = torch.nn.ModuleList() + for submodule in module: + new_module.append(torch.compile(submodule, **compile_kwargs)) + elif has_repeated_blocks(module): + new_module = module.__class__.__new__(module.__class__) + new_module.__dict__.update(module.__dict__) + new_module._modules = {} + for name, submodule in module.named_children(): + new_module.add_module(name, _compile_regions(submodule, **compile_kwargs)) + else: + new_module = torch.compile(module, **compile_kwargs) + + return new_module + + new_module = _compile_regions(module, **compile_kwargs) + + if "_orig_mod" not in new_module.__dict__: + # Keeps a reference to the original module to decompile/unwrap it later + new_module.__dict__["_orig_mod"] = module + + return new_module + + +def compile_regions_deepspeed(module: torch.nn.Module, **compile_kwargs): + """ + Performs regional compilation the same way as `compile_regions`, but specifically for `DeepSpeedEngine.module`. + Since the model is wrapped in a `DeepSpeedEngine` and has many added hooks, offloaded parameters, etc that + `torch.compile(...)` interferes with, version of trgional compilation uses the inplace `module.compile()` method + instead. + + Args: + module (`torch.nn.Module`): + The model to compile. + **compile_kwargs: + Additional keyword arguments to pass to `module.compile()`. + """ + + if is_repeated_blocks(module): + for submodule in module: + submodule.compile(**compile_kwargs) + elif has_repeated_blocks(module): + for child in module.children(): + compile_regions_deepspeed(child, **compile_kwargs) + else: # leaf node + module.compile(**compile_kwargs) + + +def model_has_dtensor(model: torch.nn.Module) -> bool: + """ + Check if the model has DTensor parameters. + + Args: + model (`torch.nn.Module`): + The model to check. + + Returns: + `bool`: Whether the model has DTensor parameters. + """ + if is_torch_version(">=", "2.5.0"): + from torch.distributed.tensor import DTensor + else: + # from torch 2.0.0 (oldest supported accelerate torch version), DTensor is in torch.distributed._tensor + from torch.distributed._tensor import DTensor + + return any(isinstance(p, DTensor) for p in model.parameters()) + + +def extract_model_from_parallel( + model, keep_fp32_wrapper: bool = True, keep_torch_compile: bool = True, recursive: bool = False +): + """ + Extract a model from its distributed containers. + + Args: + model (`torch.nn.Module`): + The model to extract. + keep_fp32_wrapper (`bool`, *optional*): + Whether to remove mixed precision hooks from the model. + keep_torch_compile (`bool`, *optional*): + Whether to unwrap compiled model. + recursive (`bool`, *optional*, defaults to `False`): + Whether to recursively extract all cases of `module.module` from `model` as well as unwrap child sublayers + recursively, not just the top-level distributed containers. + + Returns: + `torch.nn.Module`: The extracted model. + """ + options = (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel) + + is_compiled = is_compiled_module(model) + has_compiled = has_compiled_regions(model) + + compiled_model = None + if is_compiled: + compiled_model = model + model = model._orig_mod + elif has_compiled: + # Skip if top-level not compiled, subs stay wrapped + if "_orig_mod" in model.__dict__: + compiled_model = model + model = model.__dict__["_orig_mod"] + + if is_deepspeed_available(): + from deepspeed import DeepSpeedEngine + + options += (DeepSpeedEngine,) + + if is_torch_version(">=", FSDP_PYTORCH_VERSION) and is_torch_distributed_available(): + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP + + options += (FSDP,) + + while isinstance(model, options): + model = model.module + + if recursive: + # This is needed in cases such as using FSDPv2 on XLA + def _recursive_unwrap(module): + # Wrapped modules are standardly wrapped as `module`, similar to the cases earlier + # with DDP, DataParallel, DeepSpeed, and FSDP + if hasattr(module, "module"): + unwrapped_module = _recursive_unwrap(module.module) + else: + unwrapped_module = module + # Next unwrap child sublayers recursively + for name, child in unwrapped_module.named_children(): + setattr(unwrapped_module, name, _recursive_unwrap(child)) + return unwrapped_module + + # Start with top-level + model = _recursive_unwrap(model) + + if not keep_fp32_wrapper: + forward = model.forward + original_forward = model.__dict__.pop("_original_forward", None) + if original_forward is not None: + while hasattr(forward, "__wrapped__"): + forward = forward.__wrapped__ + if forward == original_forward: + break + model.forward = MethodType(forward, model) + if getattr(model, "_converted_to_transformer_engine", False): + convert_model(model, to_transformer_engine=False) + + if keep_torch_compile and compiled_model is not None: + if is_compiled: + compiled_model._orig_mod = model + model = compiled_model + elif has_compiled: + compiled_model.__dict__["_orig_mod"] = model + model = compiled_model + + return model + + +def wait_for_everyone(): + """ + Introduces a blocking point in the script, making sure all processes have reached this point before continuing. + + + + Make sure all processes will reach this instruction otherwise one of your processes will hang forever. + + + """ + PartialState().wait_for_everyone() + + +def clean_state_dict_for_safetensors(state_dict: dict): + """ + Cleans the state dictionary from a model and removes tensor aliasing if present. + + Args: + state_dict (`dict`): + The state dictionary from a model + """ + ptrs = collections.defaultdict(list) + # When bnb serialization is used, weights in state dict can be strings + for name, tensor in state_dict.items(): + if not isinstance(tensor, str): + ptrs[id_tensor_storage(tensor)].append(name) + + # These are all pointers of tensors with shared memory + shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1} + warn_names = set() + for names in shared_ptrs.values(): + # When not all duplicates have been cleaned, we still remove those keys but put a clear warning. + # If the link between tensors was done at runtime then `from_pretrained` will not get + # the key back leading to random tensor. A proper warning will be shown + # during reload (if applicable), but since the file is not necessarily compatible with + # the config, better show a proper warning. + found_names = [name for name in names if name in state_dict] + warn_names.update(found_names[1:]) + for name in found_names[1:]: + del state_dict[name] + if len(warn_names) > 0: + logger.warning( + f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading", + ) + state_dict = {k: v.contiguous() if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()} + return state_dict + + +def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = False): + """ + Save the data to disk. Use in place of `torch.save()`. + + Args: + obj: + The data to save + f: + The file (or file-like object) to use to save the data + save_on_each_node (`bool`, *optional*, defaults to `False`): + Whether to only save on the global main process + safe_serialization (`bool`, *optional*, defaults to `False`): + Whether to save `obj` using `safetensors` or the traditional PyTorch way (that uses `pickle`). + """ + # When TorchXLA is enabled, it's necessary to transfer all data to the CPU before saving. + # Another issue arises with `id_tensor_storage`, which treats all XLA tensors as identical. + # If tensors remain on XLA, calling `clean_state_dict_for_safetensors` will result in only + # one XLA tensor remaining. + if PartialState().distributed_type == DistributedType.XLA: + obj = xm._maybe_convert_to_cpu(obj) + # Check if it's a model and remove duplicates + if safe_serialization: + save_func = partial(safe_save_file, metadata={"format": "pt"}) + if isinstance(obj, OrderedDict): + obj = clean_state_dict_for_safetensors(obj) + else: + save_func = torch.save + + if PartialState().is_main_process and not save_on_each_node: + save_func(obj, f) + elif PartialState().is_local_main_process and save_on_each_node: + save_func(obj, f) + + +# The following are considered "safe" globals to reconstruct various types of objects when using `weights_only=True` +# These should be added and then removed after loading in the file +np_core = np._core if is_numpy_available("2.0.0") else np.core +TORCH_SAFE_GLOBALS = [ + # numpy arrays are just numbers, not objects, so we can reconstruct them safely + np_core.multiarray._reconstruct, + np.ndarray, + # The following are needed for the RNG states + encode, + np.dtype, +] + +if is_numpy_available("1.25.0"): + TORCH_SAFE_GLOBALS.append(np.dtypes.UInt32DType) + + +def load(f, map_location=None, **kwargs): + """ + Compatible drop-in replacement of `torch.load()` which allows for `weights_only` to be used if `torch` version is + 2.4.0 or higher. Otherwise will ignore the kwarg. + + Will also add (and then remove) an exception for numpy arrays + + Args: + f: + The file (or file-like object) to use to load the data + map_location: + a function, `torch.device`, string or a dict specifying how to remap storage locations + **kwargs: + Additional keyword arguments to pass to `torch.load()`. + """ + try: + if is_weights_only_available(): + old_safe_globals = torch.serialization.get_safe_globals() + if "weights_only" not in kwargs: + kwargs["weights_only"] = True + torch.serialization.add_safe_globals(TORCH_SAFE_GLOBALS) + else: + kwargs.pop("weights_only", None) + loaded_obj = torch.load(f, map_location=map_location, **kwargs) + finally: + if is_weights_only_available(): + torch.serialization.clear_safe_globals() + if old_safe_globals: + torch.serialization.add_safe_globals(old_safe_globals) + return loaded_obj + + +def get_pretty_name(obj): + """ + Gets a pretty name from `obj`. + """ + if not hasattr(obj, "__qualname__") and not hasattr(obj, "__name__"): + obj = getattr(obj, "__class__", obj) + if hasattr(obj, "__qualname__"): + return obj.__qualname__ + if hasattr(obj, "__name__"): + return obj.__name__ + return str(obj) + + +def merge_dicts(source, destination): + """ + Recursively merges two dictionaries. + + Args: + source (`dict`): The dictionary to merge into `destination`. + destination (`dict`): The dictionary to merge `source` into. + """ + for key, value in source.items(): + if isinstance(value, dict): + node = destination.setdefault(key, {}) + merge_dicts(value, node) + else: + destination[key] = value + + return destination + + +def is_port_in_use(port: Optional[int] = None) -> bool: + """ + Checks if a port is in use on `localhost`. Useful for checking if multiple `accelerate launch` commands have been + run and need to see if the port is already in use. + """ + if port is None: + port = 29500 + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex(("localhost", port)) == 0 + + +def get_free_port() -> int: + """ + Gets a free port on `localhost`. Useful for automatic port selection when port 0 is specified in distributed + training scenarios. + + Returns: + int: An available port number + """ + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) # bind to port 0 for OS to assign a free port + return s.getsockname()[1] + + +def convert_bytes(size): + "Converts `size` from bytes to the largest possible unit" + for x in ["bytes", "KB", "MB", "GB", "TB"]: + if size < 1024.0: + return f"{round(size, 2)} {x}" + size /= 1024.0 + + return f"{round(size, 2)} PB" + + +def check_os_kernel(): + """Warns if the kernel version is below the recommended minimum on Linux.""" + # see issue #1929 + info = platform.uname() + system = info.system + if system != "Linux": + return + + _, version, *_ = re.split(r"(\d+\.\d+\.\d+)", info.release) + min_version = "5.5.0" + if Version(version) < Version(min_version): + msg = ( + f"Detected kernel version {version}, which is below the recommended minimum of {min_version}; this can " + "cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher." + ) + logger.warning(msg, main_process_only=True) + + +def recursive_getattr(obj, attr: str): + """ + Recursive `getattr`. + + Args: + obj: + A class instance holding the attribute. + attr (`str`): + The attribute that is to be retrieved, e.g. 'attribute1.attribute2'. + """ + + def _getattr(obj, attr): + return getattr(obj, attr) + + return reduce(_getattr, [obj] + attr.split(".")) + + +def get_module_children_bottom_up(model: torch.nn.Module, return_fqns: bool = False) -> list[torch.nn.Module]: + """Traverse the model in bottom-up order and return the children modules in that order. + + Args: + model (`torch.nn.Module`): the model to get the children of + + Returns: + `list[torch.nn.Module]`: a list of children modules of `model` in bottom-up order. The last element is the + `model` itself. + """ + top = model if not return_fqns else ("", model) + stack = [top] + ordered_modules = [] + while stack: + current_module = stack.pop() + if return_fqns: + current_module_name, current_module = current_module + for name, attr in current_module.named_children(): + if isinstance(attr, torch.nn.Module): + if return_fqns: + child_name = current_module_name + "." + name if current_module_name else name + stack.append((child_name, attr)) + else: + stack.append(attr) + if return_fqns: + ordered_modules.append((current_module_name, current_module)) + else: + ordered_modules.append(current_module) + return ordered_modules[::-1] diff --git a/accelerate/utils/random.py b/accelerate/utils/random.py new file mode 100644 index 0000000000000000000000000000000000000000..49639b417dfa41259223c468d6f56b5f7926502c --- /dev/null +++ b/accelerate/utils/random.py @@ -0,0 +1,165 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +from typing import Optional, Union + +import numpy as np +import torch + +from ..state import AcceleratorState +from .constants import CUDA_DISTRIBUTED_TYPES +from .dataclasses import DistributedType, RNGType +from .imports import ( + is_hpu_available, + is_mlu_available, + is_musa_available, + is_neuron_available, + is_npu_available, + is_sdaa_available, + is_torch_xla_available, + is_xpu_available, +) + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + +def set_seed(seed: int, device_specific: bool = False, deterministic: bool = False): + """ + Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. + + Args: + seed (`int`): + The seed to set. + device_specific (`bool`, *optional*, defaults to `False`): + Whether to differ the seed on each device slightly with `self.process_index`. + deterministic (`bool`, *optional*, defaults to `False`): + Whether to use deterministic algorithms where available. Can slow down training. + """ + if device_specific: + seed += AcceleratorState().process_index + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if is_xpu_available(): + torch.xpu.manual_seed_all(seed) + elif is_npu_available(): + torch.npu.manual_seed_all(seed) + elif is_mlu_available(): + torch.mlu.manual_seed_all(seed) + elif is_sdaa_available(): + torch.sdaa.manual_seed_all(seed) + elif is_musa_available(): + torch.musa.manual_seed_all(seed) + elif is_hpu_available(): + torch.hpu.manual_seed_all(seed) + elif is_neuron_available(): + torch.neuron.manual_seed_all(seed) + else: + torch.cuda.manual_seed_all(seed) + # ^^ safe to call this function even if cuda is not available + if is_torch_xla_available(): + xm.set_rng_state(seed) + + if deterministic: + torch.use_deterministic_algorithms(True) + + +def synchronize_rng_state(rng_type: Optional[RNGType] = None, generator: Optional[torch.Generator] = None): + # Get the proper rng state + if rng_type == RNGType.TORCH: + rng_state = torch.get_rng_state() + elif rng_type == RNGType.CUDA: + rng_state = torch.cuda.get_rng_state() + elif rng_type == RNGType.XLA: + assert is_torch_xla_available(), "Can't synchronize XLA seeds as torch_xla is unavailable." + rng_state = torch.tensor(xm.get_rng_state()) + elif rng_type == RNGType.NPU: + assert is_npu_available(), "Can't synchronize NPU seeds on an environment without NPUs." + rng_state = torch.npu.get_rng_state() + elif rng_type == RNGType.MLU: + assert is_mlu_available(), "Can't synchronize MLU seeds on an environment without MLUs." + rng_state = torch.mlu.get_rng_state() + elif rng_type == RNGType.SDAA: + assert is_sdaa_available(), "Can't synchronize SDAA seeds on an environment without SDAAs." + rng_state = torch.sdaa.get_rng_state() + elif rng_type == RNGType.MUSA: + assert is_musa_available(), "Can't synchronize MUSA seeds on an environment without MUSAs." + rng_state = torch.musa.get_rng_state() + elif rng_type == RNGType.XPU: + assert is_xpu_available(), "Can't synchronize XPU seeds on an environment without XPUs." + rng_state = torch.xpu.get_rng_state() + elif rng_type == RNGType.HPU: + assert is_hpu_available(), "Can't synchronize HPU seeds on an environment without HPUs." + rng_state = torch.hpu.get_rng_state() + elif rng_type == RNGType.NEURON: + assert is_neuron_available(), "Can't synchronize Neuron seeds on an environment without Neuron Cores." + rng_state = torch.neuron.get_rng_state() + elif rng_type == RNGType.GENERATOR: + assert generator is not None, "Need a generator to synchronize its seed." + rng_state = generator.get_state() + + # Broadcast the rng state from device 0 to other devices + state = AcceleratorState() + if state.distributed_type == DistributedType.XLA: + rng_state = rng_state.to(xm.xla_device()) + xm.collective_broadcast([rng_state]) + xm.mark_step() + rng_state = rng_state.cpu() + elif ( + state.distributed_type in CUDA_DISTRIBUTED_TYPES + or state.distributed_type == DistributedType.MULTI_MLU + or state.distributed_type == DistributedType.MULTI_SDAA + or state.distributed_type == DistributedType.MULTI_MUSA + or state.distributed_type == DistributedType.MULTI_NPU + or state.distributed_type == DistributedType.MULTI_XPU + or state.distributed_type == DistributedType.MULTI_HPU + or state.distributed_type == DistributedType.MULTI_NEURON + ): + rng_state = rng_state.to(state.device) + torch.distributed.broadcast(rng_state, 0) + rng_state = rng_state.cpu() + elif state.distributed_type == DistributedType.MULTI_CPU: + torch.distributed.broadcast(rng_state, 0) + + # Set the broadcast rng state + if rng_type == RNGType.TORCH: + torch.set_rng_state(rng_state) + elif rng_type == RNGType.CUDA: + torch.cuda.set_rng_state(rng_state) + elif rng_type == RNGType.NPU: + torch.npu.set_rng_state(rng_state) + elif rng_type == RNGType.MLU: + torch.mlu.set_rng_state(rng_state) + elif rng_type == RNGType.SDAA: + torch.sdaa.set_rng_state(rng_state) + elif rng_type == RNGType.MUSA: + torch.musa.set_rng_state(rng_state) + elif rng_type == RNGType.XPU: + torch.xpu.set_rng_state(rng_state) + elif rng_type == RNGType.HPU: + torch.hpu.set_rng_state(rng_state) + elif rng_type == RNGType.NEURON: + torch.neuron.set_rng_state(rng_state) + elif rng_type == RNGType.XLA: + xm.set_rng_state(rng_state.item()) + elif rng_type == RNGType.GENERATOR: + generator.set_state(rng_state) + + +def synchronize_rng_states(rng_types: list[Union[str, RNGType]], generator: Optional[torch.Generator] = None): + for rng_type in rng_types: + synchronize_rng_state(RNGType(rng_type), generator=generator) diff --git a/accelerate/utils/rich.py b/accelerate/utils/rich.py new file mode 100644 index 0000000000000000000000000000000000000000..2d48661b7fcef92ef1168b74cc275c6d3ccc67a1 --- /dev/null +++ b/accelerate/utils/rich.py @@ -0,0 +1,24 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .imports import is_rich_available + + +if is_rich_available(): + from rich.traceback import install + + install(show_locals=False) + +else: + raise ModuleNotFoundError("To use the rich extension, install rich with `pip install rich`") diff --git a/accelerate/utils/torch_xla.py b/accelerate/utils/torch_xla.py new file mode 100644 index 0000000000000000000000000000000000000000..140133926c2f88d39c70f5a9f46a08f88bed36da --- /dev/null +++ b/accelerate/utils/torch_xla.py @@ -0,0 +1,51 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.metadata +import subprocess +import sys + + +def install_xla(upgrade: bool = False): + """ + Helper function to install appropriate xla wheels based on the `torch` version in Google Colaboratory. + + Args: + upgrade (`bool`, *optional*, defaults to `False`): + Whether to upgrade `torch` and install the latest `torch_xla` wheels. + + Example: + + ```python + >>> from accelerate.utils import install_xla + + >>> install_xla(upgrade=True) + ``` + """ + in_colab = False + if "IPython" in sys.modules: + in_colab = "google.colab" in str(sys.modules["IPython"].get_ipython()) + + if in_colab: + if upgrade: + torch_install_cmd = ["pip", "install", "-U", "torch"] + subprocess.run(torch_install_cmd, check=True) + # get the current version of torch + torch_version = importlib.metadata.version("torch") + torch_version_trunc = torch_version[: torch_version.rindex(".")] + xla_wheel = f"https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-{torch_version_trunc}-cp37-cp37m-linux_x86_64.whl" + xla_install_cmd = ["pip", "install", xla_wheel] + subprocess.run(xla_install_cmd, check=True) + else: + raise RuntimeError("`install_xla` utility works only on google colab.") diff --git a/accelerate/utils/tqdm.py b/accelerate/utils/tqdm.py new file mode 100644 index 0000000000000000000000000000000000000000..2d4873c1573eb2ee7392162f440a76d4f07cd8ce --- /dev/null +++ b/accelerate/utils/tqdm.py @@ -0,0 +1,43 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .imports import is_tqdm_available + + +if is_tqdm_available(): + from tqdm.auto import tqdm as _tqdm + +from ..state import PartialState + + +def tqdm(*args, main_process_only: bool = True, **kwargs): + """ + Wrapper around `tqdm.tqdm` that optionally displays only on the main process. + + Args: + main_process_only (`bool`, *optional*): + Whether to display the progress bar only on the main process + """ + if not is_tqdm_available(): + raise ImportError("Accelerate's `tqdm` module requires `tqdm` to be installed. Please run `pip install tqdm`.") + if len(args) > 0 and isinstance(args[0], bool): + raise ValueError( + "Passing `True` or `False` as the first argument to Accelerate's `tqdm` wrapper is unsupported. " + "Please use the `main_process_only` keyword argument instead." + ) + disable = kwargs.pop("disable", False) + if main_process_only and not disable: + disable = PartialState().local_process_index != 0 + return _tqdm(*args, **kwargs, disable=disable) diff --git a/accelerate/utils/transformer_engine.py b/accelerate/utils/transformer_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..31b61364fe2630f2b03a0ac64749333b14f6eada --- /dev/null +++ b/accelerate/utils/transformer_engine.py @@ -0,0 +1,186 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from types import MethodType + +import torch.nn as nn + +from .imports import is_hpu_available, is_transformer_engine_available +from .operations import GatheredParameters + + +# Do not import `transformer_engine` at package level to avoid potential issues + + +def convert_model(model, to_transformer_engine=True, _convert_linear=True, _convert_ln=True): + """ + Recursively converts the linear and layernorm layers of a model to their `transformers_engine` counterpart. + """ + if not is_transformer_engine_available(): + raise ImportError("Using `convert_model` requires transformer_engine to be installed.") + + if is_hpu_available(): + import intel_transformer_engine as te + + if not hasattr(te, "LayerNorm"): + # HPU does not have a LayerNorm implementation in TE + te.LayerNorm = nn.LayerNorm + else: + import transformer_engine.pytorch as te + + for name, module in model.named_children(): + if isinstance(module, nn.Linear) and to_transformer_engine and _convert_linear: + has_bias = module.bias is not None + params_to_gather = [module.weight] + if has_bias: + params_to_gather.append(module.bias) + + with GatheredParameters(params_to_gather, modifier_rank=0): + if any(p % 16 != 0 for p in module.weight.shape): + return + te_module = te.Linear( + module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype + ) + te_module.weight.copy_(module.weight) + if has_bias: + te_module.bias.copy_(module.bias) + + setattr(model, name, te_module) + # Note: @xrsrke (Phuc) found that te.LayerNorm doesn't have any real memory savings or speedups over nn.LayerNorm + elif isinstance(module, nn.LayerNorm) and to_transformer_engine and _convert_ln: + with GatheredParameters([module.weight, module.bias], modifier_rank=0): + has_bias = module.bias is not None + te_module = te.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype) + te_module.weight.copy_(module.weight) + if has_bias: + te_module.bias.copy_(module.bias) + + setattr(model, name, te_module) + elif isinstance(module, te.Linear) and not to_transformer_engine and _convert_linear: + has_bias = module.bias is not None + new_module = nn.Linear( + module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype + ) + new_module.weight.copy_(module.weight) + if has_bias: + new_module.bias.copy_(module.bias) + + setattr(model, name, new_module) + elif isinstance(module, te.LayerNorm) and not to_transformer_engine and _convert_ln: + new_module = nn.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype) + new_module.weight.copy_(module.weight) + new_module.bias.copy_(module.bias) + + setattr(model, name, new_module) + else: + convert_model( + module, + to_transformer_engine=to_transformer_engine, + _convert_linear=_convert_linear, + _convert_ln=_convert_ln, + ) + + +def has_transformer_engine_layers(model): + """ + Returns whether a given model has some `transformer_engine` layer or not. + """ + if not is_transformer_engine_available(): + raise ImportError("Using `has_transformer_engine_layers` requires transformer_engine to be installed.") + + if is_hpu_available(): + import intel_transformer_engine as te + + module_cls_to_check = te.Linear + else: + import transformer_engine.pytorch as te + + module_cls_to_check = (te.LayerNorm, te.Linear, te.TransformerLayer) + + for m in model.modules(): + if isinstance(m, module_cls_to_check): + return True + + return False + + +def contextual_fp8_autocast(model_forward, fp8_recipe, use_during_eval=False): + """ + Wrapper for a model's forward method to apply FP8 autocast. Is context aware, meaning that by default it will + disable FP8 autocast during eval mode, which is generally better for more accurate metrics. + """ + if not is_transformer_engine_available(): + raise ImportError("Using `contextual_fp8_autocast` requires transformer_engine to be installed.") + + if is_hpu_available(): + from intel_transformer_engine import fp8_autocast + else: + from transformer_engine.pytorch import fp8_autocast + + def forward(self, *args, **kwargs): + enabled = use_during_eval or self.training + with fp8_autocast(enabled=enabled, fp8_recipe=fp8_recipe): + return model_forward(*args, **kwargs) + + # To act like a decorator so that it can be popped when doing `extract_model_from_parallel` + forward.__wrapped__ = model_forward + + return forward + + +def apply_fp8_autowrap(model, fp8_recipe_handler): + """ + Applies FP8 context manager to the model's forward method + """ + if not is_transformer_engine_available(): + raise ImportError("Using `apply_fp8_autowrap` requires transformer_engine to be installed.") + + if is_hpu_available(): + import intel_transformer_engine.recipe as te_recipe + + is_fp8_block_scaling_available = False + message = "MXFP8 block scaling is not available on HPU." + + else: + import transformer_engine.common.recipe as te_recipe + from transformer_engine.pytorch.fp8 import check_mxfp8_support + + is_fp8_block_scaling_available, message = check_mxfp8_support() + + kwargs = fp8_recipe_handler.to_kwargs() if fp8_recipe_handler is not None else {} + if "fp8_format" in kwargs: + kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"]) + use_during_eval = kwargs.pop("use_autocast_during_eval", False) + use_mxfp8_block_scaling = kwargs.pop("use_mxfp8_block_scaling", False) + + if use_mxfp8_block_scaling and not is_fp8_block_scaling_available: + raise ValueError(f"MXFP8 block scaling is not available: {message}") + + if use_mxfp8_block_scaling: + if "amax_compute_algo" in kwargs: + raise ValueError("`amax_compute_algo` is not supported for MXFP8 block scaling.") + if "amax_history_len" in kwargs: + raise ValueError("`amax_history_len` is not supported for MXFP8 block scaling.") + fp8_recipe = te_recipe.MXFP8BlockScaling(**kwargs) + else: + fp8_recipe = te_recipe.DelayedScaling(**kwargs) + + new_forward = contextual_fp8_autocast(model.forward, fp8_recipe, use_during_eval) + + if hasattr(model.forward, "__func__"): + model.forward = MethodType(new_forward, model) + else: + model.forward = new_forward + + return model diff --git a/accelerate/utils/versions.py b/accelerate/utils/versions.py new file mode 100644 index 0000000000000000000000000000000000000000..985c918f0e057bacc70c372f6906071bb73db577 --- /dev/null +++ b/accelerate/utils/versions.py @@ -0,0 +1,56 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.metadata +from typing import Union + +from packaging.version import Version, parse + +from .constants import STR_OPERATION_TO_FUNC + + +torch_version = parse(importlib.metadata.version("torch")) + + +def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str): + """ + Compares a library version to some requirement using a given operation. + + Args: + library_or_version (`str` or `packaging.version.Version`): + A library name or a version to check. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="`. + requirement_version (`str`): + The version to compare the library version against + """ + if operation not in STR_OPERATION_TO_FUNC.keys(): + raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}") + operation = STR_OPERATION_TO_FUNC[operation] + if isinstance(library_or_version, str): + library_or_version = parse(importlib.metadata.version(library_or_version)) + return operation(library_or_version, parse(requirement_version)) + + +def is_torch_version(operation: str, version: str): + """ + Compares the current PyTorch version to a given reference with an operation. + + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A string version of PyTorch + """ + return compare_versions(torch_version, operation, version) diff --git a/bitsandbytes-0.49.2.dist-info/licenses/LICENSE b/bitsandbytes-0.49.2.dist-info/licenses/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..b96dcb0480a0b0be0727976e5202a1e7b23edc3f --- /dev/null +++ b/bitsandbytes-0.49.2.dist-info/licenses/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) Facebook, Inc. and its affiliates. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/certifi-2026.2.25.dist-info/licenses/LICENSE b/certifi-2026.2.25.dist-info/licenses/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..62b076cdee58ec8f34034141ba0befd9015b0c7e --- /dev/null +++ b/certifi-2026.2.25.dist-info/licenses/LICENSE @@ -0,0 +1,20 @@ +This package contains a modified version of ca-bundle.crt: + +ca-bundle.crt -- Bundle of CA Root Certificates + +This is a bundle of X.509 certificates of public Certificate Authorities +(CA). These were automatically extracted from Mozilla's root certificates +file (certdata.txt). This file can be found in the mozilla source tree: +https://hg.mozilla.org/mozilla-central/file/tip/security/nss/lib/ckfw/builtins/certdata.txt +It contains the certificates in PEM format and therefore +can be directly used with curl / libcurl / php_curl, or with +an Apache+mod_ssl webserver for SSL client authentication. +Just configure this file as the SSLCACertificateFile.# + +***** BEGIN LICENSE BLOCK ***** +This Source Code Form is subject to the terms of the Mozilla Public License, +v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain +one at http://mozilla.org/MPL/2.0/. + +***** END LICENSE BLOCK ***** +@(#) $RCSfile: certdata.txt,v $ $Revision: 1.80 $ $Date: 2011/11/03 15:11:58 $ diff --git a/click-8.3.1.dist-info/licenses/LICENSE.txt b/click-8.3.1.dist-info/licenses/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..d12a849186982399c537c5b9a8fd77bf2edd5eab --- /dev/null +++ b/click-8.3.1.dist-info/licenses/LICENSE.txt @@ -0,0 +1,28 @@ +Copyright 2014 Pallets + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/cuda_bindings-12.9.4.dist-info/licenses/LICENSE b/cuda_bindings-12.9.4.dist-info/licenses/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..b7d042fcee35d11ffb37ffa7b8fb7d2bee3f7999 --- /dev/null +++ b/cuda_bindings-12.9.4.dist-info/licenses/LICENSE @@ -0,0 +1,48 @@ +NVIDIA SOFTWARE LICENSE + +This license is a legal agreement between you and NVIDIA Corporation ("NVIDIA") and governs your use of the NVIDIA CUDA Python software and materials provided hereunder ("SOFTWARE"). + +This license can be accepted only by an adult of legal age of majority in the country in which the SOFTWARE is used. If you are under the legal age of majority, you must ask your parent or legal guardian to consent to this license. By taking delivery of the SOFTWARE, you affirm that you have reached the legal age of majority, you accept the terms of this license, and you take legal and financial responsibility for the actions of your permitted users. + +You agree to use the SOFTWARE only for purposes that are permitted by (a) this license, and (b) any applicable law, regulation or generally accepted practices or guidelines in the relevant jurisdictions. + +1. LICENSE. Subject to the terms of this license, NVIDIA grants you a non-exclusive limited license to: (a) install and use the SOFTWARE, and (b) distribute the SOFTWARE subject to the distribution requirements described in this license. NVIDIA reserves all rights, title and interest in and to the SOFTWARE not expressly granted to you under this license. + +2. DISTRIBUTION REQUIREMENTS. These are the distribution requirements for you to exercise the distribution grant: +a. The terms under which you distribute the SOFTWARE must be consistent with the terms of this license, including (without limitation) terms relating to the license grant and license restrictions and protection of NVIDIA's intellectual property rights. +b. You agree to notify NVIDIA in writing of any known or suspected distribution or use of the SOFTWARE not in compliance with the requirements of this license, and to enforce the terms of your agreements with respect to distributed SOFTWARE. + +3. LIMITATIONS. Your license to use the SOFTWARE is restricted as follows: +a. The SOFTWARE is licensed for you to develop applications only for use in systems with NVIDIA GPUs. +b. You may not reverse engineer, decompile or disassemble, or remove copyright or other proprietary notices from any portion of the SOFTWARE or copies of the SOFTWARE. +c. You may not modify or create derivative works of any portion of the SOFTWARE. +d. You may not bypass, disable, or circumvent any technical measure, encryption, security, digital rights management or authentication mechanism in the SOFTWARE. +e. You may not use the SOFTWARE in any manner that would cause it to become subject to an open source software license. As examples, licenses that require as a condition of use, modification, and/or distribution that the SOFTWARE be (i) disclosed or distributed in source code form; (ii) licensed for the purpose of making derivative works; or (iii) redistributable at no charge. +f. Unless you have an agreement with NVIDIA for this purpose, you may not use the SOFTWARE with any system or application where the use or failure of the system or application can reasonably be expected to threaten or result in personal injury, death, or catastrophic loss. Examples include use in avionics, navigation, military, medical, life support or other life critical applications. NVIDIA does not design, test or manufacture the SOFTWARE for these critical uses and NVIDIA shall not be liable to you or any third party, in whole or in part, for any claims or damages arising from such uses. +g. You agree to defend, indemnify and hold harmless NVIDIA and its affiliates, and their respective employees, contractors, agents, officers and directors, from and against any and all claims, damages, obligations, losses, liabilities, costs or debt, fines, restitutions and expenses (including but not limited to attorney's fees and costs incident to establishing the right of indemnification) arising out of or related to use of the SOFTWARE outside of the scope of this Agreement, or not in compliance with its terms. + +4. PRE-RELEASE. SOFTWARE versions identified as alpha, beta, preview, early access or otherwise as pre-release may not be fully functional, may contain errors or design flaws, and may have reduced or different security, privacy, availability, and reliability standards relative to commercial versions of NVIDIA software and materials. You may use a pre-release SOFTWARE version at your own risk, understanding that these versions are not intended for use in production or business-critical systems. + +5. OWNERSHIP. The SOFTWARE and the related intellectual property rights therein are and will remain the sole and exclusive property of NVIDIA or its licensors. The SOFTWARE is copyrighted and protected by the laws of the United States and other countries, and international treaty provisions. NVIDIA may make changes to the SOFTWARE, at any time without notice, but is not obligated to support or update the SOFTWARE. + +6. COMPONENTS UNDER OTHER LICENSES. The SOFTWARE may include NVIDIA or third-party components with separate legal notices or terms as may be described in proprietary notices accompanying the SOFTWARE. If and to the extent there is a conflict between the terms in this license and the license terms associated with a component, the license terms associated with the components control only to the extent necessary to resolve the conflict. + +7. FEEDBACK. You may, but don't have to, provide to NVIDIA any Feedback. "Feedback" means any suggestions, bug fixes, enhancements, modifications, feature requests or other feedback regarding the SOFTWARE. For any Feedback that you voluntarily provide, you hereby grant NVIDIA and its affiliates a perpetual, non-exclusive, worldwide, irrevocable license to use, reproduce, modify, license, sublicense (through multiple tiers of sublicensees), and distribute (through multiple tiers of distributors) the Feedback without the payment of any royalties or fees to you. NVIDIA will use Feedback at its choice. + +8. NO WARRANTIES. THE SOFTWARE IS PROVIDED "AS IS" WITHOUT ANY EXPRESS OR IMPLIED WARRANTY OF ANY KIND INCLUDING, BUT NOT LIMITED TO, WARRANTIES OF MERCHANTABILITY, NONINFRINGEMENT, OR FITNESS FOR A PARTICULAR PURPOSE. NVIDIA DOES NOT WARRANT THAT THE SOFTWARE WILL MEET YOUR REQUIREMENTS OR THAT THE OPERATION THEREOF WILL BE UNINTERRUPTED OR ERROR-FREE, OR THAT ALL ERRORS WILL BE CORRECTED. + +9. LIMITATIONS OF LIABILITY. TO THE MAXIMUM EXTENT PERMITTED BY LAW, NVIDIA AND ITS AFFILIATES SHALL NOT BE LIABLE FOR ANY SPECIAL, INCIDENTAL, PUNITIVE OR CONSEQUENTIAL DAMAGES, OR ANY LOST PROFITS, PROJECT DELAYS, LOSS OF USE, LOSS OF DATA OR LOSS OF GOODWILL, OR THE COSTS OF PROCURING SUBSTITUTE PRODUCTS, ARISING OUT OF OR IN CONNECTION WITH THIS LICENSE OR THE USE OR PERFORMANCE OF THE SOFTWARE, WHETHER SUCH LIABILITY ARISES FROM ANY CLAIM BASED UPON BREACH OF CONTRACT, BREACH OF WARRANTY, TORT (INCLUDING NEGLIGENCE), PRODUCT LIABILITY OR ANY OTHER CAUSE OF ACTION OR THEORY OF LIABILITY, EVEN IF NVIDIA HAS PREVIOUSLY BEEN ADVISED OF, OR COULD REASONABLY HAVE FORESEEN, THE POSSIBILITY OF SUCH DAMAGES. IN NO EVENT WILL NVIDIA'S AND ITS AFFILIATES TOTAL CUMULATIVE LIABILITY UNDER OR ARISING OUT OF THIS LICENSE EXCEED US$10.00. THE NATURE OF THE LIABILITY OR THE NUMBER OF CLAIMS OR SUITS SHALL NOT ENLARGE OR EXTEND THIS LIMIT. + +10. TERMINATION. Your rights under this license will terminate automatically without notice from NVIDIA if you fail to comply with any term and condition of this license or if you commence or participate in any legal proceeding against NVIDIA with respect to the SOFTWARE. NVIDIA may terminate this license with advance written notice to you if NVIDIA decides to no longer provide the SOFTWARE in a country or, in NVIDIA's sole discretion, the continued use of it is no longer commercially viable. Upon any termination of this license, you agree to promptly discontinue use of the SOFTWARE and destroy all copies in your possession or control. Your prior distributions in accordance with this license are not affected by the termination of this license. All provisions of this license will survive termination, except for the license granted to you. + +11. APPLICABLE LAW. This license will be governed in all respects by the laws of the United States and of the State of Delaware as those laws are applied to contracts entered into and performed entirely within Delaware by Delaware residents, without regard to the conflicts of laws principles. The United Nations Convention on Contracts for the International Sale of Goods is specifically disclaimed. You agree to all terms of this Agreement in the English language. The state or federal courts residing in Santa Clara County, California shall have exclusive jurisdiction over any dispute or claim arising out of this license. Notwithstanding this, you agree that NVIDIA shall still be allowed to apply for injunctive remedies or an equivalent type of urgent legal relief in any jurisdiction. + +12. NO ASSIGNMENT. This license and your rights and obligations thereunder may not be assigned by you by any means or operation of law without NVIDIA's permission. Any attempted assignment not approved by NVIDIA in writing shall be void and of no effect. + +13. EXPORT. The SOFTWARE is subject to United States export laws and regulations. You agree that you will not ship, transfer or export the SOFTWARE into any country, or use the SOFTWARE in any manner, prohibited by the United States Bureau of Industry and Security or economic sanctions regulations administered by the U.S. Department of Treasury's Office of Foreign Assets Control (OFAC), or any applicable export laws, restrictions or regulations. These laws include restrictions on destinations, end users and end use. By accepting this license, you confirm that you are not a resident or citizen of any country currently embargoed by the U.S. and that you are not otherwise prohibited from receiving the SOFTWARE. + +14. GOVERNMENT USE. The SOFTWARE has been developed entirely at private expense and is "commercial items" consisting of "commercial computer software" and "commercial computer software documentation" provided with RESTRICTED RIGHTS. Use, duplication or disclosure by the U.S. Government or a U.S. Government subcontractor is subject to the restrictions in this license pursuant to DFARS 227.7202-3(a) or as set forth in subparagraphs (b)(1) and (2) of the Commercial Computer Software - Restricted Rights clause at FAR 52.227-19, as applicable. Contractor/manufacturer is NVIDIA, 2788 San Tomas Expressway, Santa Clara, CA 95051. + +15. ENTIRE AGREEMENT. This license is the final, complete and exclusive agreement between the parties relating to the subject matter of this license and supersedes all prior or contemporaneous understandings and agreements relating to this subject matter, whether oral or written. If any court of competent jurisdiction determines that any provision of this license is illegal, invalid or unenforceable, the remaining provisions will remain in full force and effect. This license may only be modified in a writing signed by an authorized representative of each party. + +(v. May 12, 2021) diff --git a/diffusers-0.37.0.dist-info/licenses/LICENSE b/diffusers-0.37.0.dist-info/licenses/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..038e32f6445e8f265bde482613cf0d2f43d86dbc --- /dev/null +++ b/diffusers-0.37.0.dist-info/licenses/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, Any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/diffusers/experimental/rl/__init__.py b/diffusers/experimental/rl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7b338d3173e12d478b6b6d6fd0e50650a0ab5a4c --- /dev/null +++ b/diffusers/experimental/rl/__init__.py @@ -0,0 +1 @@ +from .value_guided_sampling import ValueGuidedRLPipeline diff --git a/diffusers/experimental/rl/value_guided_sampling.py b/diffusers/experimental/rl/value_guided_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..c69d308ecc6894edd39557d7f16fc01914e4bd85 --- /dev/null +++ b/diffusers/experimental/rl/value_guided_sampling.py @@ -0,0 +1,153 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +import tqdm + +from ...models.unets.unet_1d import UNet1DModel +from ...pipelines import DiffusionPipeline +from ...utils.dummy_pt_objects import DDPMScheduler +from ...utils.torch_utils import randn_tensor + + +class ValueGuidedRLPipeline(DiffusionPipeline): + r""" + Pipeline for value-guided sampling from a diffusion model trained to predict sequences of states. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + value_function ([`UNet1DModel`]): + A specialized UNet for fine-tuning trajectories base on reward. + unet ([`UNet1DModel`]): + UNet architecture to denoise the encoded trajectories. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this + application is [`DDPMScheduler`]. + env (): + An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models. + """ + + def __init__( + self, + value_function: UNet1DModel, + unet: UNet1DModel, + scheduler: DDPMScheduler, + env, + ): + super().__init__() + + self.register_modules(value_function=value_function, unet=unet, scheduler=scheduler, env=env) + + self.data = env.get_dataset() + self.means = {} + for key in self.data.keys(): + try: + self.means[key] = self.data[key].mean() + except: # noqa: E722 + pass + self.stds = {} + for key in self.data.keys(): + try: + self.stds[key] = self.data[key].std() + except: # noqa: E722 + pass + self.state_dim = env.observation_space.shape[0] + self.action_dim = env.action_space.shape[0] + + def normalize(self, x_in, key): + return (x_in - self.means[key]) / self.stds[key] + + def de_normalize(self, x_in, key): + return x_in * self.stds[key] + self.means[key] + + def to_torch(self, x_in): + if isinstance(x_in, dict): + return {k: self.to_torch(v) for k, v in x_in.items()} + elif torch.is_tensor(x_in): + return x_in.to(self.unet.device) + return torch.tensor(x_in, device=self.unet.device) + + def reset_x0(self, x_in, cond, act_dim): + for key, val in cond.items(): + x_in[:, key, act_dim:] = val.clone() + return x_in + + def run_diffusion(self, x, conditions, n_guide_steps, scale): + batch_size = x.shape[0] + y = None + for i in tqdm.tqdm(self.scheduler.timesteps): + # create batch of timesteps to pass into model + timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long) + for _ in range(n_guide_steps): + with torch.enable_grad(): + x.requires_grad_() + + # permute to match dimension for pre-trained models + y = self.value_function(x.permute(0, 2, 1), timesteps).sample + grad = torch.autograd.grad([y.sum()], [x])[0] + + posterior_variance = self.scheduler._get_variance(i) + model_std = torch.exp(0.5 * posterior_variance) + grad = model_std * grad + + grad[timesteps < 2] = 0 + x = x.detach() + x = x + scale * grad + x = self.reset_x0(x, conditions, self.action_dim) + + prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1) + + # TODO: verify deprecation of this kwarg + x = self.scheduler.step(prev_x, i, x)["prev_sample"] + + # apply conditions to the trajectory (set the initial state) + x = self.reset_x0(x, conditions, self.action_dim) + x = self.to_torch(x) + return x, y + + def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1): + # normalize the observations and create batch dimension + obs = self.normalize(obs, "observations") + obs = obs[None].repeat(batch_size, axis=0) + + conditions = {0: self.to_torch(obs)} + shape = (batch_size, planning_horizon, self.state_dim + self.action_dim) + + # generate initial noise and apply our conditions (to make the trajectories start at current state) + x1 = randn_tensor(shape, device=self.unet.device) + x = self.reset_x0(x1, conditions, self.action_dim) + x = self.to_torch(x) + + # run the diffusion process + x, y = self.run_diffusion(x, conditions, n_guide_steps, scale) + + # sort output trajectories by value + sorted_idx = y.argsort(0, descending=True).squeeze() + sorted_values = x[sorted_idx] + actions = sorted_values[:, :, : self.action_dim] + actions = actions.detach().cpu().numpy() + denorm_actions = self.de_normalize(actions, key="actions") + + # select the action with the highest value + if y is not None: + selected_index = 0 + else: + # if we didn't run value guiding, select a random action + selected_index = np.random.randint(0, batch_size) + + denorm_actions = denorm_actions[selected_index, 0] + return denorm_actions diff --git a/diffusers/modular_pipelines/flux/__init__.py b/diffusers/modular_pipelines/flux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4754ed01ce6aee9c85259fef1774ea96fbd27009 --- /dev/null +++ b/diffusers/modular_pipelines/flux/__init__.py @@ -0,0 +1,49 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modular_blocks_flux"] = ["FluxAutoBlocks"] + _import_structure["modular_blocks_flux_kontext"] = ["FluxKontextAutoBlocks"] + _import_structure["modular_pipeline"] = ["FluxKontextModularPipeline", "FluxModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_blocks_flux import FluxAutoBlocks + from .modular_blocks_flux_kontext import FluxKontextAutoBlocks + from .modular_pipeline import FluxKontextModularPipeline, FluxModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/modular_pipelines/flux/before_denoise.py b/diffusers/modular_pipelines/flux/before_denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..c28154775f5ad4d7c7b0bb6dd9471f934a9100cd --- /dev/null +++ b/diffusers/modular_pipelines/flux/before_denoise.py @@ -0,0 +1,618 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import numpy as np +import torch + +from ...pipelines import FluxPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import FluxModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def _get_initial_timesteps_and_optionals( + transformer, + scheduler, + batch_size, + height, + width, + vae_scale_factor, + num_inference_steps, + guidance_scale, + sigmas, + device, +): + image_seq_len = (int(height) // vae_scale_factor // 2) * (int(width) // vae_scale_factor // 2) + + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas: + sigmas = None + mu = calculate_shift( + image_seq_len, + scheduler.config.get("base_image_seq_len", 256), + scheduler.config.get("max_image_seq_len", 4096), + scheduler.config.get("base_shift", 0.5), + scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu) + if transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(batch_size) + else: + guidance = None + + return timesteps, num_inference_steps, sigmas, guidance + + +class FluxSetTimestepsStep(ModularPipelineBlocks): + model_name = "flux" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("guidance_scale", default=3.5), + InputParam("latents", type_hint=torch.Tensor), + InputParam("num_images_per_prompt", default=1), + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam( + "num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time", + ), + OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used."), + ] + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + scheduler = components.scheduler + transformer = components.transformer + + batch_size = block_state.batch_size * block_state.num_images_per_prompt + timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals( + transformer, + scheduler, + batch_size, + block_state.height, + block_state.width, + components.vae_scale_factor, + block_state.num_inference_steps, + block_state.guidance_scale, + block_state.sigmas, + block_state.device, + ) + block_state.timesteps = timesteps + block_state.num_inference_steps = num_inference_steps + block_state.sigmas = sigmas + block_state.guidance = guidance + + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + components.scheduler.set_begin_index(0) + + self.set_block_state(state, block_state) + return components, state + + +class FluxImg2ImgSetTimestepsStep(ModularPipelineBlocks): + model_name = "flux" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("strength", default=0.6), + InputParam("guidance_scale", default=3.5), + InputParam("num_images_per_prompt", default=1), + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam( + "num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time", + ), + OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used."), + ] + + @staticmethod + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps with self.scheduler->scheduler + def get_timesteps(scheduler, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = scheduler.timesteps[t_start * scheduler.order :] + if hasattr(scheduler, "set_begin_index"): + scheduler.set_begin_index(t_start * scheduler.order) + + return timesteps, num_inference_steps - t_start + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + + scheduler = components.scheduler + transformer = components.transformer + batch_size = block_state.batch_size * block_state.num_images_per_prompt + timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals( + transformer, + scheduler, + batch_size, + block_state.height, + block_state.width, + components.vae_scale_factor, + block_state.num_inference_steps, + block_state.guidance_scale, + block_state.sigmas, + block_state.device, + ) + timesteps, num_inference_steps = self.get_timesteps( + scheduler, num_inference_steps, block_state.strength, block_state.device + ) + block_state.timesteps = timesteps + block_state.num_inference_steps = num_inference_steps + block_state.sigmas = sigmas + block_state.guidance = guidance + + self.set_block_state(state, block_state) + return components, state + + +class FluxPrepareLatentsStep(ModularPipelineBlocks): + model_name = "flux" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [] + + @property + def description(self) -> str: + return "Prepare latents step that prepares the latents for the text-to-image generation process" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("latents", type_hint=torch.Tensor | None), + InputParam("num_images_per_prompt", type_hint=int, default=1), + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.", + ), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ), + ] + + @staticmethod + def check_inputs(components, block_state): + if (block_state.height is not None and block_state.height % (components.vae_scale_factor * 2) != 0) or ( + block_state.width is not None and block_state.width % (components.vae_scale_factor * 2) != 0 + ): + logger.warning( + f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}." + ) + + @staticmethod + def prepare_latents( + comp, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (comp.vae_scale_factor * 2)) + width = 2 * (int(width) // (comp.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + return latents.to(device=device, dtype=dtype) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # TODO: move packing latents code to a patchifier similar to Qwen + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = FluxPipeline._pack_latents(latents, batch_size, num_channels_latents, height, width) + + return latents + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + block_state.device = components._execution_device + block_state.num_channels_latents = components.num_channels_latents + + self.check_inputs(components, block_state) + batch_size = block_state.batch_size * block_state.num_images_per_prompt + block_state.latents = self.prepare_latents( + components, + batch_size, + block_state.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + ) + + self.set_block_state(state, block_state) + + return components, state + + +class FluxImg2ImgPrepareLatentsStep(ModularPipelineBlocks): + model_name = "flux" + + @property + def description(self) -> str: + return "Step that adds noise to image latents for image-to-image. Should be run after `set_timesteps`," + " `prepare_latents`. Both noise and image latents should already be patchified." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The initial random noised, can be generated in prepare latent step.", + ), + InputParam( + name="image_latents", + required=True, + type_hint=torch.Tensor, + description="The image latents to use for the denoising process. Can be generated in vae encoder and packed in input step.", + ), + InputParam( + name="timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="initial_noise", + type_hint=torch.Tensor, + description="The initial random noised used for inpainting denoising.", + ), + ] + + @staticmethod + def check_inputs(image_latents, latents): + if image_latents.shape[0] != latents.shape[0]: + raise ValueError( + f"`image_latents` must have have same batch size as `latents`, but got {image_latents.shape[0]} and {latents.shape[0]}" + ) + + if image_latents.ndim != 3: + raise ValueError(f"`image_latents` must have 3 dimensions (patchified), but got {image_latents.ndim}") + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + self.check_inputs(image_latents=block_state.image_latents, latents=block_state.latents) + + # prepare latent timestep + latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0]) + + # make copy of initial_noise + block_state.initial_noise = block_state.latents + + # scale noise + block_state.latents = components.scheduler.scale_noise( + block_state.image_latents, latent_timestep, block_state.latents + ) + + self.set_block_state(state, block_state) + + return components, state + + +class FluxRoPEInputsStep(ModularPipelineBlocks): + model_name = "flux" + + @property + def description(self) -> str: + return "Step that prepares the RoPE inputs for the denoising process. Should be placed after text encoder and latent preparation steps." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam(name="height", required=True), + InputParam(name="width", required=True), + InputParam(name="prompt_embeds"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="txt_ids", + kwargs_type="denoiser_input_fields", + type_hint=list[int], + description="The sequence lengths of the prompt embeds, used for RoPE calculation.", + ), + OutputParam( + name="img_ids", + kwargs_type="denoiser_input_fields", + type_hint=list[int], + description="The sequence lengths of the image latents, used for RoPE calculation.", + ), + ] + + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + prompt_embeds = block_state.prompt_embeds + device, dtype = prompt_embeds.device, prompt_embeds.dtype + block_state.txt_ids = torch.zeros(prompt_embeds.shape[1], 3).to( + device=prompt_embeds.device, dtype=prompt_embeds.dtype + ) + + height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2)) + width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2)) + block_state.img_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype) + + self.set_block_state(state, block_state) + + return components, state + + +class FluxKontextRoPEInputsStep(ModularPipelineBlocks): + model_name = "flux-kontext" + + @property + def description(self) -> str: + return "Step that prepares the RoPE inputs for the denoising process of Flux Kontext. Should be placed after text encoder and latent preparation steps." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam(name="image_height"), + InputParam(name="image_width"), + InputParam(name="height"), + InputParam(name="width"), + InputParam(name="prompt_embeds"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="txt_ids", + kwargs_type="denoiser_input_fields", + type_hint=list[int], + description="The sequence lengths of the prompt embeds, used for RoPE calculation.", + ), + OutputParam( + name="img_ids", + kwargs_type="denoiser_input_fields", + type_hint=list[int], + description="The sequence lengths of the image latents, used for RoPE calculation.", + ), + ] + + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + prompt_embeds = block_state.prompt_embeds + device, dtype = prompt_embeds.device, prompt_embeds.dtype + block_state.txt_ids = torch.zeros(prompt_embeds.shape[1], 3).to( + device=prompt_embeds.device, dtype=prompt_embeds.dtype + ) + + img_ids = None + if ( + getattr(block_state, "image_height", None) is not None + and getattr(block_state, "image_width", None) is not None + ): + image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2)) + image_latent_width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2)) + img_ids = FluxPipeline._prepare_latent_image_ids( + None, image_latent_height // 2, image_latent_width // 2, device, dtype + ) + # image ids are the same as latent ids with the first dimension set to 1 instead of 0 + img_ids[..., 0] = 1 + + height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2)) + width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2)) + latent_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype) + + if img_ids is not None: + latent_ids = torch.cat([latent_ids, img_ids], dim=0) + + block_state.img_ids = latent_ids + + self.set_block_state(state, block_state) + + return components, state diff --git a/diffusers/modular_pipelines/flux/decoders.py b/diffusers/modular_pipelines/flux/decoders.py new file mode 100644 index 0000000000000000000000000000000000000000..5da861e78fcba81baa14d4eacadae232a17b324f --- /dev/null +++ b/diffusers/modular_pipelines/flux/decoders.py @@ -0,0 +1,109 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import numpy as np +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKL +from ...utils import logging +from ...video_processor import VaeImageProcessor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + +class FluxDecodeStep(ModularPipelineBlocks): + model_name = "flux" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images" + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("output_type", default="pil"), + InputParam("height", default=1024), + InputParam("width", default=1024), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents from the denoising step", + ), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "images", + type_hint=list[PIL.Image.Image] | torch.Tensor | np.ndarray, + description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array", + ) + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + vae = components.vae + + if not block_state.output_type == "latent": + latents = block_state.latents + latents = _unpack_latents(latents, block_state.height, block_state.width, components.vae_scale_factor) + latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor + block_state.images = vae.decode(latents, return_dict=False)[0] + block_state.images = components.image_processor.postprocess( + block_state.images, output_type=block_state.output_type + ) + else: + block_state.images = block_state.latents + + self.set_block_state(state, block_state) + + return components, state diff --git a/diffusers/modular_pipelines/flux/denoise.py b/diffusers/modular_pipelines/flux/denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..babb4a867e5958efcf4921104336999c20b98de0 --- /dev/null +++ b/diffusers/modular_pipelines/flux/denoise.py @@ -0,0 +1,330 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from ...models import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import FluxModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class FluxLoopDenoiser(ModularPipelineBlocks): + model_name = "flux" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("transformer", FluxTransformer2DModel)] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `FluxDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("joint_attention_kwargs"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "guidance", + required=False, + type_hint=torch.Tensor, + description="Guidance scale as a tensor", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Prompt embeddings", + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Pooled prompt embeddings", + ), + InputParam( + "txt_ids", + required=True, + type_hint=torch.Tensor, + description="IDs computed from text sequence needed for RoPE", + ), + InputParam( + "img_ids", + required=True, + type_hint=torch.Tensor, + description="IDs computed from image sequence needed for RoPE", + ), + ] + + @torch.no_grad() + def __call__( + self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + noise_pred = components.transformer( + hidden_states=block_state.latents, + timestep=t.flatten() / 1000, + guidance=block_state.guidance, + encoder_hidden_states=block_state.prompt_embeds, + pooled_projections=block_state.pooled_prompt_embeds, + joint_attention_kwargs=block_state.joint_attention_kwargs, + txt_ids=block_state.txt_ids, + img_ids=block_state.img_ids, + return_dict=False, + )[0] + block_state.noise_pred = noise_pred + + return components, block_state + + +class FluxKontextLoopDenoiser(ModularPipelineBlocks): + model_name = "flux-kontext" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("transformer", FluxTransformer2DModel)] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents for Flux Kontext. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `FluxDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("joint_attention_kwargs"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "image_latents", + type_hint=torch.Tensor, + description="Image latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "guidance", + required=False, + type_hint=torch.Tensor, + description="Guidance scale as a tensor", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Prompt embeddings", + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Pooled prompt embeddings", + ), + InputParam( + "txt_ids", + required=True, + type_hint=torch.Tensor, + description="IDs computed from text sequence needed for RoPE", + ), + InputParam( + "img_ids", + required=True, + type_hint=torch.Tensor, + description="IDs computed from latent sequence needed for RoPE", + ), + ] + + @torch.no_grad() + def __call__( + self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + latents = block_state.latents + latent_model_input = latents + image_latents = block_state.image_latents + if image_latents is not None: + latent_model_input = torch.cat([latent_model_input, image_latents], dim=1) + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + noise_pred = components.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=block_state.guidance, + encoder_hidden_states=block_state.prompt_embeds, + pooled_projections=block_state.pooled_prompt_embeds, + joint_attention_kwargs=block_state.joint_attention_kwargs, + txt_ids=block_state.txt_ids, + img_ids=block_state.img_ids, + return_dict=False, + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + block_state.noise_pred = noise_pred + + return components, block_state + + +class FluxLoopAfterDenoiser(ModularPipelineBlocks): + model_name = "flux" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that update the latents. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `FluxDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [] + + @property + def intermediate_inputs(self) -> list[str]: + return [InputParam("generator")] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + # Perform scheduler step using the predicted output + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred, + t, + block_state.latents, + return_dict=False, + )[0] + + if block_state.latents.dtype != latents_dtype: + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class FluxDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "flux" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoise the latents over `timesteps`. " + "The specific steps with each iteration can be customized with `sub_blocks` attributes" + ) + + @property + def loop_expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ComponentSpec("transformer", FluxTransformer2DModel), + ] + + @property + def loop_inputs(self) -> list[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + ) + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + self.set_block_state(state, block_state) + + return components, state + + +class FluxDenoiseStep(FluxDenoiseLoopWrapper): + block_classes = [FluxLoopDenoiser, FluxLoopAfterDenoiser] + block_names = ["denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `FluxDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `FluxLoopDenoiser`\n" + " - `FluxLoopAfterDenoiser`\n" + "This block supports both text2image and img2img tasks." + ) + + +class FluxKontextDenoiseStep(FluxDenoiseLoopWrapper): + model_name = "flux-kontext" + block_classes = [FluxKontextLoopDenoiser, FluxLoopAfterDenoiser] + block_names = ["denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `FluxDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `FluxKontextLoopDenoiser`\n" + " - `FluxLoopAfterDenoiser`\n" + "This block supports both text2image and img2img tasks." + ) diff --git a/diffusers/modular_pipelines/flux/encoders.py b/diffusers/modular_pipelines/flux/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..583c139ff22e60173efc31a97654b243e8576d1c --- /dev/null +++ b/diffusers/modular_pipelines/flux/encoders.py @@ -0,0 +1,480 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html + +import regex as re +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist +from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL +from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import FluxModularPipeline + + +if is_ftfy_available(): + import ftfy + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def encode_vae_image(vae: AutoencoderKL, image: torch.Tensor, generator: torch.Generator, sample_mode="sample"): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode=sample_mode) + + image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor + + return image_latents + + +class FluxProcessImagesInputStep(ModularPipelineBlocks): + model_name = "flux" + + @property + def description(self) -> str: + return "Image Preprocess step." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [InputParam("resized_image"), InputParam("image"), InputParam("height"), InputParam("width")] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam(name="processed_image")] + + @staticmethod + def check_inputs(height, width, vae_scale_factor): + if height is not None and height % (vae_scale_factor * 2) != 0: + raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}") + + if width is not None and width % (vae_scale_factor * 2) != 0: + raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}") + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + if block_state.resized_image is None and block_state.image is None: + raise ValueError("`resized_image` and `image` cannot be None at the same time") + + if block_state.resized_image is None: + image = block_state.image + self.check_inputs( + height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor + ) + height = block_state.height or components.default_height + width = block_state.width or components.default_width + else: + width, height = block_state.resized_image[0].size + image = block_state.resized_image + + block_state.processed_image = components.image_processor.preprocess(image=image, height=height, width=width) + + self.set_block_state(state, block_state) + return components, state + + +class FluxKontextProcessImagesInputStep(ModularPipelineBlocks): + model_name = "flux-kontext" + + @property + def description(self) -> str: + return ( + "Image preprocess step for Flux Kontext. The preprocessed image goes to the VAE.\n" + "Kontext works as a T2I model, too, in case no input image is provided." + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [InputParam("image"), InputParam("_auto_resize", type_hint=bool, default=True)] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam(name="processed_image")] + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState): + from ...pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS + + block_state = self.get_block_state(state) + images = block_state.image + + if images is None: + block_state.processed_image = None + + else: + multiple_of = components.image_processor.config.vae_scale_factor + + if not is_valid_image_imagelist(images): + raise ValueError(f"Images must be image or list of images but are {type(images)}") + + if is_valid_image(images): + images = [images] + + img = images[0] + image_height, image_width = components.image_processor.get_default_height_width(img) + aspect_ratio = image_width / image_height + _auto_resize = block_state._auto_resize + if _auto_resize: + # Kontext is trained on specific resolutions, using one of them is recommended + _, image_width, image_height = min( + (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS + ) + image_width = image_width // multiple_of * multiple_of + image_height = image_height // multiple_of * multiple_of + images = components.image_processor.resize(images, image_height, image_width) + block_state.processed_image = components.image_processor.preprocess(images, image_height, image_width) + + self.set_block_state(state, block_state) + return components, state + + +class FluxVaeEncoderStep(ModularPipelineBlocks): + model_name = "flux" + + def __init__( + self, input_name: str = "processed_image", output_name: str = "image_latents", sample_mode: str = "sample" + ): + """Initialize a VAE encoder step for converting images to latent representations. + + Both the input and output names are configurable so this block can be configured to process to different image + inputs (e.g., "processed_image" -> "image_latents", "processed_control_image" -> "control_image_latents"). + + Args: + input_name (str, optional): Name of the input image tensor. Defaults to "processed_image". + Examples: "processed_image" or "processed_control_image" + output_name (str, optional): Name of the output latent tensor. Defaults to "image_latents". + Examples: "image_latents" or "control_image_latents" + sample_mode (str, optional): Sampling mode to be used. + + Examples: + # Basic usage with default settings (includes image processor): # FluxImageVaeEncoderDynamicStep() + + # Custom input/output names for control image: # FluxImageVaeEncoderDynamicStep( + input_name="processed_control_image", output_name="control_image_latents" + ) + """ + self._image_input_name = input_name + self._image_latents_output_name = output_name + self.sample_mode = sample_mode + super().__init__() + + @property + def description(self) -> str: + return f"Dynamic VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n" + + @property + def expected_components(self) -> list[ComponentSpec]: + components = [ComponentSpec("vae", AutoencoderKL)] + return components + + @property + def inputs(self) -> list[InputParam]: + inputs = [InputParam(self._image_input_name), InputParam("generator")] + return inputs + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + self._image_latents_output_name, + type_hint=torch.Tensor, + description="The latents representing the reference image", + ) + ] + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + image = getattr(block_state, self._image_input_name) + + if image is None: + setattr(block_state, self._image_latents_output_name, None) + else: + device = components._execution_device + dtype = components.vae.dtype + image = image.to(device=device, dtype=dtype) + + # Encode image into latents + image_latents = encode_vae_image( + image=image, vae=components.vae, generator=block_state.generator, sample_mode=self.sample_mode + ) + setattr(block_state, self._image_latents_output_name, image_latents) + + self.set_block_state(state, block_state) + + return components, state + + +class FluxTextEncoderStep(ModularPipelineBlocks): + model_name = "flux" + + @property + def description(self) -> str: + return "Text Encoder step that generate text_embeddings to guide the image generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", CLIPTextModel), + ComponentSpec("tokenizer", CLIPTokenizer), + ComponentSpec("text_encoder_2", T5EncoderModel), + ComponentSpec("tokenizer_2", T5TokenizerFast), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("prompt"), + InputParam("prompt_2"), + InputParam("max_sequence_length", type_hint=int, default=512, required=False), + InputParam("joint_attention_kwargs"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="text embeddings used to guide the image generation", + ), + OutputParam( + "pooled_prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="pooled text embeddings used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + for prompt in [block_state.prompt, block_state.prompt_2]: + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` or `prompt_2` has to be of type `str` or `list` but is {type(prompt)}") + + @staticmethod + def _get_t5_prompt_embeds(components, prompt: str | list[str], max_sequence_length: int, device: torch.device): + dtype = components.text_encoder_2.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + + if isinstance(components, TextualInversionLoaderMixin): + prompt = components.maybe_convert_prompt(prompt, components.tokenizer_2) + + text_inputs = components.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + untruncated_ids = components.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = components.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = components.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + return prompt_embeds + + @staticmethod + def _get_clip_prompt_embeds(components, prompt: str | list[str], device: torch.device): + prompt = [prompt] if isinstance(prompt, str) else prompt + + if isinstance(components, TextualInversionLoaderMixin): + prompt = components.maybe_convert_prompt(prompt, components.tokenizer) + + text_inputs = components.tokenizer( + prompt, + padding="max_length", + max_length=components.tokenizer.model_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + tokenizer_max_length = components.tokenizer.model_max_length + untruncated_ids = components.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = components.tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = components.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=components.text_encoder.dtype, device=device) + + return prompt_embeds + + @staticmethod + def encode_prompt( + components, + prompt: str | list[str], + prompt_2: str | list[str], + device: torch.device | None = None, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 512, + lora_scale: float | None = None, + ): + device = device or components._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(components, FluxLoraLoaderMixin): + components._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if components.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(components.text_encoder, lora_scale) + if components.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(components.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = FluxTextEncoderStep._get_clip_prompt_embeds( + components, + prompt=prompt, + device=device, + ) + prompt_embeds = FluxTextEncoderStep._get_t5_prompt_embeds( + components, + prompt=prompt_2, + max_sequence_length=max_sequence_length, + device=device, + ) + + if components.text_encoder is not None: + if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(components.text_encoder, lora_scale) + + if components.text_encoder_2 is not None: + if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(components.text_encoder_2, lora_scale) + + return prompt_embeds, pooled_prompt_embeds + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + # Get inputs and intermediates + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.device = components._execution_device + + # Encode input prompt + block_state.text_encoder_lora_scale = ( + block_state.joint_attention_kwargs.get("scale", None) + if block_state.joint_attention_kwargs is not None + else None + ) + block_state.prompt_embeds, block_state.pooled_prompt_embeds = self.encode_prompt( + components, + prompt=block_state.prompt, + prompt_2=None, + prompt_embeds=None, + pooled_prompt_embeds=None, + device=block_state.device, + max_sequence_length=block_state.max_sequence_length, + lora_scale=block_state.text_encoder_lora_scale, + ) + + # Add outputs + self.set_block_state(state, block_state) + return components, state diff --git a/diffusers/modular_pipelines/flux/inputs.py b/diffusers/modular_pipelines/flux/inputs.py new file mode 100644 index 0000000000000000000000000000000000000000..9d2f69dbe26fb0fdf1fb3c96ed2fb54b181b023f --- /dev/null +++ b/diffusers/modular_pipelines/flux/inputs.py @@ -0,0 +1,363 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +from ...pipelines import FluxPipeline +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import InputParam, OutputParam + +# TODO: consider making these common utilities for modular if they are not pipeline-specific. +from ..qwenimage.inputs import calculate_dimension_from_latents, repeat_tensor_to_batch_size +from .modular_pipeline import FluxModularPipeline + + +logger = logging.get_logger(__name__) + + +class FluxTextInputStep(ModularPipelineBlocks): + model_name = "flux" + + @property + def description(self) -> str: + return ( + "Text input processing step that standardizes text embeddings for the pipeline.\n" + "This step:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + InputParam( + "prompt_embeds", + required=True, + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "pooled_prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.", + ), + # TODO: support negative embeddings? + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `prompt_embeds`)", + ), + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="text embeddings used to guide the image generation", + ), + OutputParam( + "pooled_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="pooled text embeddings used to guide the image generation", + ), + # TODO: support negative embeddings? + ] + + def check_inputs(self, components, block_state): + if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is not None: + if block_state.prompt_embeds.shape[0] != block_state.pooled_prompt_embeds.shape[0]: + raise ValueError( + "`prompt_embeds` and `pooled_prompt_embeds` must have the same batch size when passed directly, but" + f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `pooled_prompt_embeds`" + f" {block_state.pooled_prompt_embeds.shape}." + ) + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + # TODO: consider adding negative embeddings? + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt) + block_state.pooled_prompt_embeds = pooled_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, -1 + ) + self.set_block_state(state, block_state) + + return components, state + + +# Adapted from `QwenImageAdditionalInputsStep` +class FluxAdditionalInputsStep(ModularPipelineBlocks): + model_name = "flux" + + def __init__( + self, + image_latent_inputs: list[str] = ["image_latents"], + additional_batch_inputs: list[str] = [], + ): + if not isinstance(image_latent_inputs, list): + image_latent_inputs = [image_latent_inputs] + if not isinstance(additional_batch_inputs, list): + additional_batch_inputs = [additional_batch_inputs] + + self._image_latent_inputs = image_latent_inputs + self._additional_batch_inputs = additional_batch_inputs + super().__init__() + + @property + def description(self) -> str: + # Functionality section + summary_section = ( + "Input processing step that:\n" + " 1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n" + " 2. For additional batch inputs: Expands batch dimensions to match final batch size" + ) + + # Inputs info + inputs_info = "" + if self._image_latent_inputs or self._additional_batch_inputs: + inputs_info = "\n\nConfigured inputs:" + if self._image_latent_inputs: + inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}" + if self._additional_batch_inputs: + inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}" + + # Placement guidance + placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." + + return summary_section + inputs_info + placement_section + + @property + def inputs(self) -> list[InputParam]: + inputs = [ + InputParam(name="num_images_per_prompt", default=1), + InputParam(name="batch_size", required=True), + InputParam(name="height"), + InputParam(name="width"), + ] + + # Add image latent inputs + for image_latent_input_name in self._image_latent_inputs: + inputs.append(InputParam(name=image_latent_input_name)) + + # Add additional batch inputs + for input_name in self._additional_batch_inputs: + inputs.append(InputParam(name=input_name)) + + return inputs + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam(name="image_height", type_hint=int, description="The height of the image latents"), + OutputParam(name="image_width", type_hint=int, description="The width of the image latents"), + ] + + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Process image latent inputs (height/width calculation, patchify, and batch expansion) + for image_latent_input_name in self._image_latent_inputs: + image_latent_tensor = getattr(block_state, image_latent_input_name) + if image_latent_tensor is None: + continue + + # 1. Calculate height/width from latents + height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor) + block_state.height = block_state.height or height + block_state.width = block_state.width or width + + if not hasattr(block_state, "image_height"): + block_state.image_height = height + if not hasattr(block_state, "image_width"): + block_state.image_width = width + + # 2. Patchify the image latent tensor + # TODO: Implement patchifier for Flux. + latent_height, latent_width = image_latent_tensor.shape[2:] + image_latent_tensor = FluxPipeline._pack_latents( + image_latent_tensor, block_state.batch_size, image_latent_tensor.shape[1], latent_height, latent_width + ) + + # 3. Expand batch size + image_latent_tensor = repeat_tensor_to_batch_size( + input_name=image_latent_input_name, + input_tensor=image_latent_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, image_latent_input_name, image_latent_tensor) + + # Process additional batch inputs (only batch expansion) + for input_name in self._additional_batch_inputs: + input_tensor = getattr(block_state, input_name) + if input_tensor is None: + continue + + # Only expand batch size + input_tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=input_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, input_name, input_tensor) + + self.set_block_state(state, block_state) + return components, state + + +class FluxKontextAdditionalInputsStep(FluxAdditionalInputsStep): + model_name = "flux-kontext" + + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Process image latent inputs (height/width calculation, patchify, and batch expansion) + for image_latent_input_name in self._image_latent_inputs: + image_latent_tensor = getattr(block_state, image_latent_input_name) + if image_latent_tensor is None: + continue + + # 1. Calculate height/width from latents + # Unlike the `FluxAdditionalInputsStep`, we don't overwrite the `block.height` and `block.width` + height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor) + if not hasattr(block_state, "image_height"): + block_state.image_height = height + if not hasattr(block_state, "image_width"): + block_state.image_width = width + + # 2. Patchify the image latent tensor + # TODO: Implement patchifier for Flux. + latent_height, latent_width = image_latent_tensor.shape[2:] + image_latent_tensor = FluxPipeline._pack_latents( + image_latent_tensor, block_state.batch_size, image_latent_tensor.shape[1], latent_height, latent_width + ) + + # 3. Expand batch size + image_latent_tensor = repeat_tensor_to_batch_size( + input_name=image_latent_input_name, + input_tensor=image_latent_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, image_latent_input_name, image_latent_tensor) + + # Process additional batch inputs (only batch expansion) + for input_name in self._additional_batch_inputs: + input_tensor = getattr(block_state, input_name) + if input_tensor is None: + continue + + # Only expand batch size + input_tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=input_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, input_name, input_tensor) + + self.set_block_state(state, block_state) + return components, state + + +class FluxKontextSetResolutionStep(ModularPipelineBlocks): + model_name = "flux-kontext" + + @property + def description(self): + return ( + "Determines the height and width to be used during the subsequent computations.\n" + "It should always be placed _before_ the latent preparation step." + ) + + @property + def inputs(self) -> list[InputParam]: + inputs = [ + InputParam(name="height"), + InputParam(name="width"), + InputParam(name="max_area", type_hint=int, default=1024**2), + ] + return inputs + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam(name="height", type_hint=int, description="The height of the initial noisy latents"), + OutputParam(name="width", type_hint=int, description="The width of the initial noisy latents"), + ] + + @staticmethod + def check_inputs(height, width, vae_scale_factor): + if height is not None and height % (vae_scale_factor * 2) != 0: + raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}") + + if width is not None and width % (vae_scale_factor * 2) != 0: + raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}") + + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + height = block_state.height or components.default_height + width = block_state.width or components.default_width + self.check_inputs(height, width, components.vae_scale_factor) + + original_height, original_width = height, width + max_area = block_state.max_area + aspect_ratio = width / height + width = round((max_area * aspect_ratio) ** 0.5) + height = round((max_area / aspect_ratio) ** 0.5) + + multiple_of = components.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + if height != original_height or width != original_width: + logger.warning( + f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements." + ) + + block_state.height = height + block_state.width = width + + self.set_block_state(state, block_state) + return components, state diff --git a/diffusers/modular_pipelines/flux/modular_blocks_flux.py b/diffusers/modular_pipelines/flux/modular_blocks_flux.py new file mode 100644 index 0000000000000000000000000000000000000000..f2e78e9334482234f2473403c2d53f8a6407f835 --- /dev/null +++ b/diffusers/modular_pipelines/flux/modular_blocks_flux.py @@ -0,0 +1,586 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict, OutputParam +from .before_denoise import ( + FluxImg2ImgPrepareLatentsStep, + FluxImg2ImgSetTimestepsStep, + FluxPrepareLatentsStep, + FluxRoPEInputsStep, + FluxSetTimestepsStep, +) +from .decoders import FluxDecodeStep +from .denoise import FluxDenoiseStep +from .encoders import ( + FluxProcessImagesInputStep, + FluxTextEncoderStep, + FluxVaeEncoderStep, +) +from .inputs import ( + FluxAdditionalInputsStep, + FluxTextInputStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# vae encoder (run before before_denoise) + + +# auto_docstring +class FluxImg2ImgVaeEncoderStep(SequentialPipelineBlocks): + """ + Vae encoder step that preprocess andencode the image inputs into their latent representations. + + Components: + image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`) + + Inputs: + resized_image (`None`, *optional*): + TODO: Add description. + image (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + processed_image (`None`): + TODO: Add description. + image_latents (`Tensor`): + The latents representing the reference image + """ + + model_name = "flux" + + block_classes = [FluxProcessImagesInputStep(), FluxVaeEncoderStep()] + block_names = ["preprocess", "encode"] + + @property + def description(self) -> str: + return "Vae encoder step that preprocess andencode the image inputs into their latent representations." + + +# auto_docstring +class FluxAutoVaeEncoderStep(AutoPipelineBlocks): + """ + Vae encoder step that encode the image inputs into their latent representations. + This is an auto pipeline block that works for img2img tasks. + - `FluxImg2ImgVaeEncoderStep` (img2img) is used when only `image` is provided. - if `image` is not provided, + step will be skipped. + + Components: + image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`) + + Inputs: + resized_image (`None`, *optional*): + TODO: Add description. + image (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + processed_image (`None`): + TODO: Add description. + image_latents (`Tensor`): + The latents representing the reference image + """ + + model_name = "flux" + block_classes = [FluxImg2ImgVaeEncoderStep] + block_names = ["img2img"] + block_trigger_inputs = ["image"] + + @property + def description(self): + return ( + "Vae encoder step that encode the image inputs into their latent representations.\n" + + "This is an auto pipeline block that works for img2img tasks.\n" + + " - `FluxImg2ImgVaeEncoderStep` (img2img) is used when only `image` is provided." + + " - if `image` is not provided, step will be skipped." + ) + + +# before_denoise: text2img +# auto_docstring +class FluxBeforeDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepares the inputs for the denoise step in text-to-image generation. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + num_images_per_prompt (`int`, *optional*, defaults to 1): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. + Can be generated in input step. + dtype (`dtype`, *optional*): + The dtype of the model inputs + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 3.5): + TODO: Add description. + prompt_embeds (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + The initial latents to use for the denoising process + timesteps (`Tensor`): + The timesteps to use for inference + num_inference_steps (`int`): + The number of denoising steps to perform at inference time + guidance (`Tensor`): + Optional guidance to be used. + txt_ids (`list`): + The sequence lengths of the prompt embeds, used for RoPE calculation. + img_ids (`list`): + The sequence lengths of the image latents, used for RoPE calculation. + """ + + model_name = "flux" + block_classes = [FluxPrepareLatentsStep(), FluxSetTimestepsStep(), FluxRoPEInputsStep()] + block_names = ["prepare_latents", "set_timesteps", "prepare_rope_inputs"] + + @property + def description(self): + return "Before denoise step that prepares the inputs for the denoise step in text-to-image generation." + + +# before_denoise: img2img +# auto_docstring +class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs for the denoise step for img2img task. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + num_images_per_prompt (`int`, *optional*, defaults to 1): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. + Can be generated in input step. + dtype (`dtype`, *optional*): + The dtype of the model inputs + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + strength (`None`, *optional*, defaults to 0.6): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 3.5): + TODO: Add description. + image_latents (`Tensor`): + The image latents to use for the denoising process. Can be generated in vae encoder and packed in input + step. + prompt_embeds (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + The initial latents to use for the denoising process + timesteps (`Tensor`): + The timesteps to use for inference + num_inference_steps (`int`): + The number of denoising steps to perform at inference time + guidance (`Tensor`): + Optional guidance to be used. + initial_noise (`Tensor`): + The initial random noised used for inpainting denoising. + txt_ids (`list`): + The sequence lengths of the prompt embeds, used for RoPE calculation. + img_ids (`list`): + The sequence lengths of the image latents, used for RoPE calculation. + """ + + model_name = "flux" + block_classes = [ + FluxPrepareLatentsStep(), + FluxImg2ImgSetTimestepsStep(), + FluxImg2ImgPrepareLatentsStep(), + FluxRoPEInputsStep(), + ] + block_names = ["prepare_latents", "set_timesteps", "prepare_img2img_latents", "prepare_rope_inputs"] + + @property + def description(self): + return "Before denoise step that prepare the inputs for the denoise step for img2img task." + + +# before_denoise: all task (text2img, img2img) +# auto_docstring +class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks): + """ + Before denoise step that prepare the inputs for the denoise step. + This is an auto pipeline block that works for text2image. + - `FluxBeforeDenoiseStep` (text2image) is used. + - `FluxImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + height (`int`): + TODO: Add description. + width (`int`): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + num_images_per_prompt (`int`, *optional*, defaults to 1): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. + Can be generated in input step. + dtype (`dtype`, *optional*): + The dtype of the model inputs + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + strength (`None`, *optional*, defaults to 0.6): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 3.5): + TODO: Add description. + image_latents (`Tensor`, *optional*): + The image latents to use for the denoising process. Can be generated in vae encoder and packed in input + step. + prompt_embeds (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + The initial latents to use for the denoising process + timesteps (`Tensor`): + The timesteps to use for inference + num_inference_steps (`int`): + The number of denoising steps to perform at inference time + guidance (`Tensor`): + Optional guidance to be used. + initial_noise (`Tensor`): + The initial random noised used for inpainting denoising. + txt_ids (`list`): + The sequence lengths of the prompt embeds, used for RoPE calculation. + img_ids (`list`): + The sequence lengths of the image latents, used for RoPE calculation. + """ + + model_name = "flux" + block_classes = [FluxImg2ImgBeforeDenoiseStep, FluxBeforeDenoiseStep] + block_names = ["img2img", "text2image"] + block_trigger_inputs = ["image_latents", None] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs for the denoise step.\n" + + "This is an auto pipeline block that works for text2image.\n" + + " - `FluxBeforeDenoiseStep` (text2image) is used.\n" + + " - `FluxImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + ) + + +# inputs: text2image/img2img + + +# auto_docstring +class FluxImg2ImgInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the img2img denoising step. It: + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be generated from text_encoder step. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + image_latents (`None`, *optional*): + TODO: Add description. + + Outputs: + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt + dtype (`dtype`): + Data type of model tensor inputs (determined by `prompt_embeds`) + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation + pooled_prompt_embeds (`Tensor`): + pooled text embeddings used to guide the image generation + image_height (`int`): + The height of the image latents + image_width (`int`): + The width of the image latents + """ + + model_name = "flux" + block_classes = [FluxTextInputStep(), FluxAdditionalInputsStep()] + block_names = ["text_inputs", "additional_inputs"] + + @property + def description(self): + return "Input step that prepares the inputs for the img2img denoising step. It:\n" + " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n" + " - update height/width based `image_latents`, patchify `image_latents`." + + +# auto_docstring +class FluxAutoInputStep(AutoPipelineBlocks): + """ + Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, + and patchified. + This is an auto pipeline block that works for text2image/img2img tasks. + - `FluxImg2ImgInputStep` (img2img) is used when `image_latents` is provided. + - `FluxTextInputStep` (text2image) is used when `image_latents` are not provided. + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be generated from text_encoder step. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + image_latents (`None`, *optional*): + TODO: Add description. + + Outputs: + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt + dtype (`dtype`): + Data type of model tensor inputs (determined by `prompt_embeds`) + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation + pooled_prompt_embeds (`Tensor`): + pooled text embeddings used to guide the image generation + image_height (`int`): + The height of the image latents + image_width (`int`): + The width of the image latents + """ + + model_name = "flux" + + block_classes = [FluxImg2ImgInputStep, FluxTextInputStep] + block_names = ["img2img", "text2image"] + block_trigger_inputs = ["image_latents", None] + + @property + def description(self): + return ( + "Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n" + " This is an auto pipeline block that works for text2image/img2img tasks.\n" + + " - `FluxImg2ImgInputStep` (img2img) is used when `image_latents` is provided.\n" + + " - `FluxTextInputStep` (text2image) is used when `image_latents` are not provided.\n" + ) + + +# auto_docstring +class FluxCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core step that performs the denoising process for Flux. + This step supports text-to-image and image-to-image tasks for Flux: + - for image-to-image generation, you need to provide `image_latents` + - for text-to-image generation, all you need to provide is prompt embeddings. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`FluxTransformer2DModel`) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be generated from text_encoder step. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + image_latents (`None`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + strength (`None`, *optional*, defaults to 0.6): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 3.5): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "flux" + block_classes = [FluxAutoInputStep, FluxAutoBeforeDenoiseStep, FluxDenoiseStep] + block_names = ["input", "before_denoise", "denoise"] + + @property + def description(self): + return ( + "Core step that performs the denoising process for Flux.\n" + + "This step supports text-to-image and image-to-image tasks for Flux:\n" + + " - for image-to-image generation, you need to provide `image_latents`\n" + + " - for text-to-image generation, all you need to provide is prompt embeddings." + ) + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Auto blocks (text2image and img2img) +AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", FluxTextEncoderStep()), + ("vae_encoder", FluxAutoVaeEncoderStep()), + ("denoise", FluxCoreDenoiseStep()), + ("decode", FluxDecodeStep()), + ] +) + + +# auto_docstring +class FluxAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for text-to-image and image-to-image using Flux. + + Supported workflows: + - `text2image`: requires `prompt` + - `image2image`: requires `image`, `prompt` + + Components: + text_encoder (`CLIPTextModel`) tokenizer (`CLIPTokenizer`) text_encoder_2 (`T5EncoderModel`) tokenizer_2 + (`T5TokenizerFast`) image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`) scheduler + (`FlowMatchEulerDiscreteScheduler`) transformer (`FluxTransformer2DModel`) + + Inputs: + prompt (`None`, *optional*): + TODO: Add description. + prompt_2 (`None`, *optional*): + TODO: Add description. + max_sequence_length (`int`, *optional*, defaults to 512): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + resized_image (`None`, *optional*): + TODO: Add description. + image (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + image_latents (`None`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + strength (`None`, *optional*, defaults to 0.6): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 3.5): + TODO: Add description. + output_type (`None`, *optional*, defaults to pil): + TODO: Add description. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "flux" + + block_classes = AUTO_BLOCKS.values() + block_names = AUTO_BLOCKS.keys() + + _workflow_map = { + "text2image": {"prompt": True}, + "image2image": {"image": True, "prompt": True}, + } + + @property + def description(self): + return "Auto Modular pipeline for text-to-image and image-to-image using Flux." + + @property + def outputs(self): + return [OutputParam.template("images")] diff --git a/diffusers/modular_pipelines/flux/modular_blocks_flux_kontext.py b/diffusers/modular_pipelines/flux/modular_blocks_flux_kontext.py new file mode 100644 index 0000000000000000000000000000000000000000..b5a5dbf78c0eb9ef2cd7a8102aca2a29aa9b6db3 --- /dev/null +++ b/diffusers/modular_pipelines/flux/modular_blocks_flux_kontext.py @@ -0,0 +1,585 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict, OutputParam +from .before_denoise import ( + FluxKontextRoPEInputsStep, + FluxPrepareLatentsStep, + FluxRoPEInputsStep, + FluxSetTimestepsStep, +) +from .decoders import FluxDecodeStep +from .denoise import FluxKontextDenoiseStep +from .encoders import ( + FluxKontextProcessImagesInputStep, + FluxTextEncoderStep, + FluxVaeEncoderStep, +) +from .inputs import ( + FluxKontextAdditionalInputsStep, + FluxKontextSetResolutionStep, + FluxTextInputStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Flux Kontext vae encoder (run before before_denoise) +# auto_docstring +class FluxKontextVaeEncoderStep(SequentialPipelineBlocks): + """ + Vae encoder step that preprocess andencode the image inputs into their latent representations. + + Components: + image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`) + + Inputs: + image (`None`, *optional*): + TODO: Add description. + _auto_resize (`bool`, *optional*, defaults to True): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + processed_image (`None`): + TODO: Add description. + image_latents (`Tensor`): + The latents representing the reference image + """ + + model_name = "flux-kontext" + + block_classes = [FluxKontextProcessImagesInputStep(), FluxVaeEncoderStep(sample_mode="argmax")] + block_names = ["preprocess", "encode"] + + @property + def description(self) -> str: + return "Vae encoder step that preprocess andencode the image inputs into their latent representations." + + +# auto_docstring +class FluxKontextAutoVaeEncoderStep(AutoPipelineBlocks): + """ + Vae encoder step that encode the image inputs into their latent representations. + This is an auto pipeline block that works for image-conditioned tasks. + - `FluxKontextVaeEncoderStep` (image_conditioned) is used when only `image` is provided. - if `image` is not + provided, step will be skipped. + + Components: + image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`) + + Inputs: + image (`None`, *optional*): + TODO: Add description. + _auto_resize (`bool`, *optional*, defaults to True): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + processed_image (`None`): + TODO: Add description. + image_latents (`Tensor`): + The latents representing the reference image + """ + + model_name = "flux-kontext" + + block_classes = [FluxKontextVaeEncoderStep] + block_names = ["image_conditioned"] + block_trigger_inputs = ["image"] + + @property + def description(self): + return ( + "Vae encoder step that encode the image inputs into their latent representations.\n" + + "This is an auto pipeline block that works for image-conditioned tasks.\n" + + " - `FluxKontextVaeEncoderStep` (image_conditioned) is used when only `image` is provided." + + " - if `image` is not provided, step will be skipped." + ) + + +# before_denoise: text2img +# auto_docstring +class FluxKontextBeforeDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepares the inputs for the denoise step for Flux Kontext + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + num_images_per_prompt (`int`, *optional*, defaults to 1): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. + Can be generated in input step. + dtype (`dtype`, *optional*): + The dtype of the model inputs + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 3.5): + TODO: Add description. + prompt_embeds (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + The initial latents to use for the denoising process + timesteps (`Tensor`): + The timesteps to use for inference + num_inference_steps (`int`): + The number of denoising steps to perform at inference time + guidance (`Tensor`): + Optional guidance to be used. + txt_ids (`list`): + The sequence lengths of the prompt embeds, used for RoPE calculation. + img_ids (`list`): + The sequence lengths of the image latents, used for RoPE calculation. + """ + + model_name = "flux-kontext" + + block_classes = [FluxPrepareLatentsStep(), FluxSetTimestepsStep(), FluxRoPEInputsStep()] + block_names = ["prepare_latents", "set_timesteps", "prepare_rope_inputs"] + + @property + def description(self): + return "Before denoise step that prepares the inputs for the denoise step for Flux Kontext\n" + "for text-to-image tasks." + + +# before_denoise: image-conditioned +# auto_docstring +class FluxKontextImageConditionedBeforeDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs for the denoise step for Flux Kontext + for image-conditioned tasks. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + num_images_per_prompt (`int`, *optional*, defaults to 1): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. + Can be generated in input step. + dtype (`dtype`, *optional*): + The dtype of the model inputs + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 3.5): + TODO: Add description. + image_height (`None`, *optional*): + TODO: Add description. + image_width (`None`, *optional*): + TODO: Add description. + prompt_embeds (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + The initial latents to use for the denoising process + timesteps (`Tensor`): + The timesteps to use for inference + num_inference_steps (`int`): + The number of denoising steps to perform at inference time + guidance (`Tensor`): + Optional guidance to be used. + txt_ids (`list`): + The sequence lengths of the prompt embeds, used for RoPE calculation. + img_ids (`list`): + The sequence lengths of the image latents, used for RoPE calculation. + """ + + model_name = "flux-kontext" + + block_classes = [FluxPrepareLatentsStep(), FluxSetTimestepsStep(), FluxKontextRoPEInputsStep()] + block_names = ["prepare_latents", "set_timesteps", "prepare_rope_inputs"] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs for the denoise step for Flux Kontext\n" + "for image-conditioned tasks." + ) + + +# auto_docstring +class FluxKontextAutoBeforeDenoiseStep(AutoPipelineBlocks): + """ + Before denoise step that prepare the inputs for the denoise step. + This is an auto pipeline block that works for text2image. + - `FluxKontextBeforeDenoiseStep` (text2image) is used. + - `FluxKontextImageConditionedBeforeDenoiseStep` (image_conditioned) is used when only `image_latents` is + provided. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + num_images_per_prompt (`int`, *optional*, defaults to 1): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. + Can be generated in input step. + dtype (`dtype`, *optional*): + The dtype of the model inputs + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 3.5): + TODO: Add description. + image_height (`None`, *optional*): + TODO: Add description. + image_width (`None`, *optional*): + TODO: Add description. + prompt_embeds (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + The initial latents to use for the denoising process + timesteps (`Tensor`): + The timesteps to use for inference + num_inference_steps (`int`): + The number of denoising steps to perform at inference time + guidance (`Tensor`): + Optional guidance to be used. + txt_ids (`list`): + The sequence lengths of the prompt embeds, used for RoPE calculation. + img_ids (`list`): + The sequence lengths of the image latents, used for RoPE calculation. + """ + + model_name = "flux-kontext" + + block_classes = [FluxKontextImageConditionedBeforeDenoiseStep, FluxKontextBeforeDenoiseStep] + block_names = ["image_conditioned", "text2image"] + block_trigger_inputs = ["image_latents", None] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs for the denoise step.\n" + + "This is an auto pipeline block that works for text2image.\n" + + " - `FluxKontextBeforeDenoiseStep` (text2image) is used.\n" + + " - `FluxKontextImageConditionedBeforeDenoiseStep` (image_conditioned) is used when only `image_latents` is provided.\n" + ) + + +# inputs: Flux Kontext +# auto_docstring +class FluxKontextInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the both text2img and img2img denoising step. It: + - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`). + - update height/width based `image_latents`, patchify `image_latents`. + + Inputs: + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + max_area (`int`, *optional*, defaults to 1048576): + TODO: Add description. + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be generated from text_encoder step. + image_latents (`None`, *optional*): + TODO: Add description. + + Outputs: + height (`int`): + The height of the initial noisy latents + width (`int`): + The width of the initial noisy latents + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt + dtype (`dtype`): + Data type of model tensor inputs (determined by `prompt_embeds`) + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation + pooled_prompt_embeds (`Tensor`): + pooled text embeddings used to guide the image generation + image_height (`int`): + The height of the image latents + image_width (`int`): + The width of the image latents + """ + + model_name = "flux-kontext" + block_classes = [FluxKontextSetResolutionStep(), FluxTextInputStep(), FluxKontextAdditionalInputsStep()] + block_names = ["set_resolution", "text_inputs", "additional_inputs"] + + @property + def description(self): + return ( + "Input step that prepares the inputs for the both text2img and img2img denoising step. It:\n" + " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n" + " - update height/width based `image_latents`, patchify `image_latents`." + ) + + +# auto_docstring +class FluxKontextAutoInputStep(AutoPipelineBlocks): + """ + Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, + and patchified. + This is an auto pipeline block that works for text2image/img2img tasks. + - `FluxKontextInputStep` (image_conditioned) is used when `image_latents` is provided. + - `FluxKontextInputStep` is also capable of handling text2image task when `image_latent` isn't present. + + Inputs: + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + max_area (`int`, *optional*, defaults to 1048576): + TODO: Add description. + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be generated from text_encoder step. + image_latents (`None`, *optional*): + TODO: Add description. + + Outputs: + height (`int`): + The height of the initial noisy latents + width (`int`): + The width of the initial noisy latents + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt + dtype (`dtype`): + Data type of model tensor inputs (determined by `prompt_embeds`) + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation + pooled_prompt_embeds (`Tensor`): + pooled text embeddings used to guide the image generation + image_height (`int`): + The height of the image latents + image_width (`int`): + The width of the image latents + """ + + model_name = "flux-kontext" + block_classes = [FluxKontextInputStep, FluxTextInputStep] + block_names = ["image_conditioned", "text2image"] + block_trigger_inputs = ["image_latents", None] + + @property + def description(self): + return ( + "Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n" + " This is an auto pipeline block that works for text2image/img2img tasks.\n" + + " - `FluxKontextInputStep` (image_conditioned) is used when `image_latents` is provided.\n" + + " - `FluxKontextInputStep` is also capable of handling text2image task when `image_latent` isn't present." + ) + + +# auto_docstring +class FluxKontextCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core step that performs the denoising process for Flux Kontext. + This step supports text-to-image and image-conditioned tasks for Flux Kontext: + - for image-conditioned generation, you need to provide `image_latents` + - for text-to-image generation, all you need to provide is prompt embeddings. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`FluxTransformer2DModel`) + + Inputs: + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + max_area (`int`, *optional*, defaults to 1048576): + TODO: Add description. + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be generated from text_encoder step. + image_latents (`None`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 3.5): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "flux-kontext" + block_classes = [FluxKontextAutoInputStep, FluxKontextAutoBeforeDenoiseStep, FluxKontextDenoiseStep] + block_names = ["input", "before_denoise", "denoise"] + + @property + def description(self): + return ( + "Core step that performs the denoising process for Flux Kontext.\n" + + "This step supports text-to-image and image-conditioned tasks for Flux Kontext:\n" + + " - for image-conditioned generation, you need to provide `image_latents`\n" + + " - for text-to-image generation, all you need to provide is prompt embeddings." + ) + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +AUTO_BLOCKS_KONTEXT = InsertableDict( + [ + ("text_encoder", FluxTextEncoderStep()), + ("vae_encoder", FluxKontextAutoVaeEncoderStep()), + ("denoise", FluxKontextCoreDenoiseStep()), + ("decode", FluxDecodeStep()), + ] +) + + +# auto_docstring +class FluxKontextAutoBlocks(SequentialPipelineBlocks): + """ + Modular pipeline for image-to-image using Flux Kontext. + + Supported workflows: + - `image_conditioned`: requires `image`, `prompt` + - `text2image`: requires `prompt` + + Components: + text_encoder (`CLIPTextModel`) tokenizer (`CLIPTokenizer`) text_encoder_2 (`T5EncoderModel`) tokenizer_2 + (`T5TokenizerFast`) image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`) scheduler + (`FlowMatchEulerDiscreteScheduler`) transformer (`FluxTransformer2DModel`) + + Inputs: + prompt (`None`, *optional*): + TODO: Add description. + prompt_2 (`None`, *optional*): + TODO: Add description. + max_sequence_length (`int`, *optional*, defaults to 512): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + image (`None`, *optional*): + TODO: Add description. + _auto_resize (`bool`, *optional*, defaults to True): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + max_area (`int`, *optional*, defaults to 1048576): + TODO: Add description. + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + image_latents (`None`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 3.5): + TODO: Add description. + output_type (`None`, *optional*, defaults to pil): + TODO: Add description. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "flux-kontext" + + block_classes = AUTO_BLOCKS_KONTEXT.values() + block_names = AUTO_BLOCKS_KONTEXT.keys() + _workflow_map = { + "image_conditioned": {"image": True, "prompt": True}, + "text2image": {"prompt": True}, + } + + @property + def description(self): + return "Modular pipeline for image-to-image using Flux Kontext." + + @property + def outputs(self): + return [OutputParam.template("images")] diff --git a/diffusers/modular_pipelines/flux/modular_pipeline.py b/diffusers/modular_pipelines/flux/modular_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..d8158f5d4fd618751e328bb92f699b818507ba05 --- /dev/null +++ b/diffusers/modular_pipelines/flux/modular_pipeline.py @@ -0,0 +1,67 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin +from ...utils import logging +from ..modular_pipeline import ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin, TextualInversionLoaderMixin): + """ + A ModularPipeline for Flux. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "FluxAutoBlocks" + + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + return 128 + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if getattr(self, "vae", None) is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_latents(self): + num_channels_latents = 16 + if getattr(self, "transformer", None): + num_channels_latents = self.transformer.config.in_channels // 4 + return num_channels_latents + + +class FluxKontextModularPipeline(FluxModularPipeline): + """ + A ModularPipeline for Flux Kontext. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "FluxKontextAutoBlocks" diff --git a/diffusers/modular_pipelines/flux2/__init__.py b/diffusers/modular_pipelines/flux2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d7cc8badcaf7d0e51f7e6eb2923df6c04d4172cd --- /dev/null +++ b/diffusers/modular_pipelines/flux2/__init__.py @@ -0,0 +1,57 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["encoders"] = ["Flux2RemoteTextEncoderStep"] + _import_structure["modular_blocks_flux2"] = ["Flux2AutoBlocks"] + _import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks"] + _import_structure["modular_blocks_flux2_klein_base"] = ["Flux2KleinBaseAutoBlocks"] + _import_structure["modular_pipeline"] = [ + "Flux2KleinBaseModularPipeline", + "Flux2KleinModularPipeline", + "Flux2ModularPipeline", + ] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .encoders import Flux2RemoteTextEncoderStep + from .modular_blocks_flux2 import Flux2AutoBlocks + from .modular_blocks_flux2_klein import Flux2KleinAutoBlocks + from .modular_blocks_flux2_klein_base import Flux2KleinBaseAutoBlocks + from .modular_pipeline import Flux2KleinBaseModularPipeline, Flux2KleinModularPipeline, Flux2ModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/modular_pipelines/flux2/before_denoise.py b/diffusers/modular_pipelines/flux2/before_denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..6b1b3bd96324dd0fc90a293c18246a18e2133541 --- /dev/null +++ b/diffusers/modular_pipelines/flux2/before_denoise.py @@ -0,0 +1,591 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import numpy as np +import torch + +from ...models import Flux2Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import Flux2ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + """Compute empirical mu for Flux2 timestep scheduling.""" + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class Flux2SetTimestepsStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ComponentSpec("transformer", Flux2Transformer2DModel), + ] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for Flux2 inference using empirical mu calculation" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("latents", type_hint=torch.Tensor), + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam( + "num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time", + ), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + + scheduler = components.scheduler + + height = block_state.height or components.default_height + width = block_state.width or components.default_width + vae_scale_factor = components.vae_scale_factor + + latent_height = 2 * (int(height) // (vae_scale_factor * 2)) + latent_width = 2 * (int(width) // (vae_scale_factor * 2)) + image_seq_len = (latent_height // 2) * (latent_width // 2) + + num_inference_steps = block_state.num_inference_steps + sigmas = block_state.sigmas + timesteps = block_state.timesteps + + if timesteps is None and sigmas is None: + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas: + sigmas = None + + mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps) + + timesteps, num_inference_steps = retrieve_timesteps( + scheduler, + num_inference_steps, + device, + timesteps=timesteps, + sigmas=sigmas, + mu=mu, + ) + block_state.timesteps = timesteps + block_state.num_inference_steps = num_inference_steps + + components.scheduler.set_begin_index(0) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2PrepareLatentsStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [] + + @property + def description(self) -> str: + return "Prepare latents step that prepares the initial noise latents for Flux2 text-to-image generation" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("latents", type_hint=torch.Tensor | None), + InputParam("num_images_per_prompt", type_hint=int, default=1), + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.", + ), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ), + OutputParam("latent_ids", type_hint=torch.Tensor, description="Position IDs for the latents (for RoPE)"), + ] + + @staticmethod + def check_inputs(components, block_state): + vae_scale_factor = components.vae_scale_factor + if (block_state.height is not None and block_state.height % (vae_scale_factor * 2) != 0) or ( + block_state.width is not None and block_state.width % (vae_scale_factor * 2) != 0 + ): + logger.warning( + f"`height` and `width` have to be divisible by {vae_scale_factor * 2} but are {block_state.height} and {block_state.width}." + ) + + @staticmethod + def _prepare_latent_ids(latents: torch.Tensor): + """ + Generates 4D position coordinates (T, H, W, L) for latent tensors. + + Args: + latents: Latent tensor of shape (B, C, H, W) + + Returns: + Position IDs tensor of shape (B, H*W, 4) + """ + batch_size, _, height, width = latents.shape + + t = torch.arange(1) + h = torch.arange(height) + w = torch.arange(width) + l = torch.arange(1) + + latent_ids = torch.cartesian_prod(t, h, w, l) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + + return latent_ids + + @staticmethod + def _pack_latents(latents): + """Pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)""" + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + return latents + + @staticmethod + def prepare_latents( + comp, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (comp.vae_scale_factor * 2)) + width = 2 * (int(width) // (comp.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents * 4, height // 2, width // 2) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + block_state.device = components._execution_device + block_state.num_channels_latents = components.num_channels_latents + + self.check_inputs(components, block_state) + batch_size = block_state.batch_size * block_state.num_images_per_prompt + + latents = self.prepare_latents( + components, + batch_size, + block_state.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + ) + + latent_ids = self._prepare_latent_ids(latents) + latent_ids = latent_ids.to(block_state.device) + + latents = self._pack_latents(latents) + + block_state.latents = latents + block_state.latent_ids = latent_ids + + self.set_block_state(state, block_state) + return components, state + + +class Flux2RoPEInputsStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "Step that prepares the 4D RoPE position IDs for Flux2 denoising. Should be placed after text encoder and latent preparation steps." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam(name="prompt_embeds", required=True), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="txt_ids", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.", + ), + ] + + @staticmethod + def _prepare_text_ids(x: torch.Tensor, t_coord: torch.Tensor | None = None): + """Prepare 4D position IDs for text tokens.""" + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + seq_l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, seq_l) + out_ids.append(coords) + + return torch.stack(out_ids) + + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + prompt_embeds = block_state.prompt_embeds + device = prompt_embeds.device + + block_state.txt_ids = self._prepare_text_ids(prompt_embeds) + block_state.txt_ids = block_state.txt_ids.to(device) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2KleinBaseRoPEInputsStep(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def description(self) -> str: + return "Step that prepares the 4D RoPE position IDs for Flux2-Klein base model denoising. Should be placed after text encoder and latent preparation steps." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam(name="prompt_embeds", required=True), + InputParam(name="negative_prompt_embeds", required=False), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="txt_ids", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.", + ), + OutputParam( + name="negative_txt_ids", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="4D position IDs (T, H, W, L) for negative text tokens, used for RoPE calculation.", + ), + ] + + @staticmethod + def _prepare_text_ids(x: torch.Tensor, t_coord: torch.Tensor | None = None): + """Prepare 4D position IDs for text tokens.""" + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + seq_l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, seq_l) + out_ids.append(coords) + + return torch.stack(out_ids) + + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + prompt_embeds = block_state.prompt_embeds + device = prompt_embeds.device + + block_state.txt_ids = self._prepare_text_ids(prompt_embeds) + block_state.txt_ids = block_state.txt_ids.to(device) + + block_state.negative_txt_ids = None + if block_state.negative_prompt_embeds is not None: + block_state.negative_txt_ids = self._prepare_text_ids(block_state.negative_prompt_embeds) + block_state.negative_txt_ids = block_state.negative_txt_ids.to(device) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2PrepareImageLatentsStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "Step that prepares image latents and their position IDs for Flux2 image conditioning." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("image_latents", type_hint=list[torch.Tensor]), + InputParam("batch_size", required=True, type_hint=int), + InputParam("num_images_per_prompt", default=1, type_hint=int), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "image_latents", + type_hint=torch.Tensor, + description="Packed image latents for conditioning", + ), + OutputParam( + "image_latent_ids", + type_hint=torch.Tensor, + description="Position IDs for image latents", + ), + ] + + @staticmethod + def _prepare_image_ids(image_latents: list[torch.Tensor], scale: int = 10): + """ + Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents. + + Args: + image_latents: A list of image latent feature tensors of shape (1, C, H, W). + scale: Factor used to define the time separation between latents. + + Returns: + Combined coordinate tensor of shape (1, N_total, 4) + """ + if not isinstance(image_latents, list): + raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.") + + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] + t_coords = [t.view(-1) for t in t_coords] + + image_latent_ids = [] + for x, t in zip(image_latents, t_coords): + x = x.squeeze(0) + _, height, width = x.shape + + x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) + image_latent_ids.append(x_ids) + + image_latent_ids = torch.cat(image_latent_ids, dim=0) + image_latent_ids = image_latent_ids.unsqueeze(0) + + return image_latent_ids + + @staticmethod + def _pack_latents(latents): + """Pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)""" + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + return latents + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + image_latents = block_state.image_latents + + if image_latents is None: + block_state.image_latents = None + block_state.image_latent_ids = None + self.set_block_state(state, block_state) + + return components, state + + device = components._execution_device + batch_size = block_state.batch_size * block_state.num_images_per_prompt + + image_latent_ids = self._prepare_image_ids(image_latents) + + packed_latents = [] + for latent in image_latents: + packed = self._pack_latents(latent) + packed = packed.squeeze(0) + packed_latents.append(packed) + + image_latents = torch.cat(packed_latents, dim=0) + image_latents = image_latents.unsqueeze(0) + + image_latents = image_latents.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.to(device) + + block_state.image_latents = image_latents + block_state.image_latent_ids = image_latent_ids + + self.set_block_state(state, block_state) + return components, state + + +class Flux2PrepareGuidanceStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "Step that prepares the guidance scale tensor for Flux2 inference" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("guidance_scale", default=4.0), + InputParam("num_images_per_prompt", default=1), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + batch_size = block_state.batch_size * block_state.num_images_per_prompt + guidance = torch.full([1], block_state.guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(batch_size) + block_state.guidance = guidance + + self.set_block_state(state, block_state) + return components, state diff --git a/diffusers/modular_pipelines/flux2/decoders.py b/diffusers/modular_pipelines/flux2/decoders.py new file mode 100644 index 0000000000000000000000000000000000000000..c8ad9401efff6ef4bd714c294cff064353e362d2 --- /dev/null +++ b/diffusers/modular_pipelines/flux2/decoders.py @@ -0,0 +1,185 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any, Union + +import numpy as np +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKLFlux2 +from ...pipelines.flux2.image_processor import Flux2ImageProcessor +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Flux2UnpackLatentsStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "Step that unpacks the latents from the denoising step" + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents from the denoising step", + ), + InputParam( + "latent_ids", + required=True, + type_hint=torch.Tensor, + description="Position IDs for the latents, used for unpacking", + ), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "latents", + type_hint=torch.Tensor, + description="The denoise latents from denoising step, unpacked with position IDs.", + ) + ] + + @staticmethod + def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> torch.Tensor: + """ + Unpack latents using position IDs to scatter tokens into place. + + Args: + x: Packed latents tensor of shape (B, seq_len, C) + x_ids: Position IDs tensor of shape (B, seq_len, 4) with (T, H, W, L) coordinates + + Returns: + Unpacked latents tensor of shape (B, C, H, W) + """ + x_list = [] + for data, pos in zip(x, x_ids): + _, ch = data.shape # noqa: F841 + h_ids = pos[:, 1].to(torch.int64) + w_ids = pos[:, 2].to(torch.int64) + + h = torch.max(h_ids) + 1 + w = torch.max(w_ids) + 1 + + flat_ids = h_ids * w + w_ids + + out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype) + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) + + out = out.view(h, w, ch).permute(2, 0, 1) + x_list.append(out) + + return torch.stack(x_list, dim=0) + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + latents = block_state.latents + latent_ids = block_state.latent_ids + + latents = self._unpack_latents_with_ids(latents, latent_ids) + + block_state.latents = latents + + self.set_block_state(state, block_state) + return components, state + + +class Flux2DecodeStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLFlux2), + ComponentSpec( + "image_processor", + Flux2ImageProcessor, + config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization" + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("output_type", default="pil"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents from the denoising step", + ), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "images", + type_hint=Union[list[PIL.Image.Image], torch.Tensor, np.ndarray], + description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array", + ) + ] + + @staticmethod + def _unpatchify_latents(latents): + """Convert patchified latents back to regular format.""" + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2) + return latents + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + vae = components.vae + + latents = block_state.latents + + latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + latents = latents * latents_bn_std + latents_bn_mean + + latents = self._unpatchify_latents(latents) + + block_state.images = vae.decode(latents, return_dict=False)[0] + block_state.images = components.image_processor.postprocess( + block_state.images, output_type=block_state.output_type + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/diffusers/modular_pipelines/flux2/denoise.py b/diffusers/modular_pipelines/flux2/denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..66783cc36953ad69cd13353fdd1648e3f1a5d221 --- /dev/null +++ b/diffusers/modular_pipelines/flux2/denoise.py @@ -0,0 +1,509 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import Flux2Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Flux2LoopDenoiser(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("transformer", Flux2Transformer2DModel)] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoises the latents for Flux2. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `Flux2DenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("joint_attention_kwargs"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to denoise. Shape: (B, seq_len, C)", + ), + InputParam( + "image_latents", + type_hint=torch.Tensor, + description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)", + ), + InputParam( + "image_latent_ids", + type_hint=torch.Tensor, + description="Position IDs for image latents. Shape: (B, img_seq_len, 4)", + ), + InputParam( + "guidance", + required=True, + type_hint=torch.Tensor, + description="Guidance scale as a tensor", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Text embeddings from Mistral3", + ), + InputParam( + "txt_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for text tokens (T, H, W, L)", + ), + InputParam( + "latent_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for latent tokens (T, H, W, L)", + ), + ] + + @torch.no_grad() + def __call__( + self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + latents = block_state.latents + latent_model_input = latents.to(components.transformer.dtype) + img_ids = block_state.latent_ids + + image_latents = getattr(block_state, "image_latents", None) + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype) + image_latent_ids = block_state.image_latent_ids + img_ids = torch.cat([img_ids, image_latent_ids], dim=1) + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = components.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=block_state.guidance, + encoder_hidden_states=block_state.prompt_embeds, + txt_ids=block_state.txt_ids, + img_ids=img_ids, + joint_attention_kwargs=block_state.joint_attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1)] + block_state.noise_pred = noise_pred + + return components, block_state + + +# same as Flux2LoopDenoiser but guidance=None +class Flux2KleinLoopDenoiser(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("transformer", Flux2Transformer2DModel)] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoises the latents for Flux2. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `Flux2DenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("joint_attention_kwargs"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to denoise. Shape: (B, seq_len, C)", + ), + InputParam( + "image_latents", + type_hint=torch.Tensor, + description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)", + ), + InputParam( + "image_latent_ids", + type_hint=torch.Tensor, + description="Position IDs for image latents. Shape: (B, img_seq_len, 4)", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Text embeddings from Qwen3", + ), + InputParam( + "txt_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for text tokens (T, H, W, L)", + ), + InputParam( + "latent_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for latent tokens (T, H, W, L)", + ), + ] + + @torch.no_grad() + def __call__( + self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + latents = block_state.latents + latent_model_input = latents.to(components.transformer.dtype) + img_ids = block_state.latent_ids + + image_latents = getattr(block_state, "image_latents", None) + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype) + image_latent_ids = block_state.image_latent_ids + img_ids = torch.cat([img_ids, image_latent_ids], dim=1) + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = components.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=None, + encoder_hidden_states=block_state.prompt_embeds, + txt_ids=block_state.txt_ids, + img_ids=img_ids, + joint_attention_kwargs=block_state.joint_attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1)] + block_state.noise_pred = noise_pred + + return components, block_state + + +# support CFG for Flux2-Klein base model +class Flux2KleinBaseLoopDenoiser(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("transformer", Flux2Transformer2DModel), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ] + + @property + def expected_configs(self) -> list[ConfigSpec]: + return [ + ConfigSpec(name="is_distilled", default=False), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoises the latents for Flux2. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `Flux2DenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("joint_attention_kwargs"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to denoise. Shape: (B, seq_len, C)", + ), + InputParam( + "image_latents", + type_hint=torch.Tensor, + description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)", + ), + InputParam( + "image_latent_ids", + type_hint=torch.Tensor, + description="Position IDs for image latents. Shape: (B, img_seq_len, 4)", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Text embeddings from Qwen3", + ), + InputParam( + "negative_prompt_embeds", + required=False, + type_hint=torch.Tensor, + description="Negative text embeddings from Qwen3", + ), + InputParam( + "txt_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for text tokens (T, H, W, L)", + ), + InputParam( + "negative_txt_ids", + required=False, + type_hint=torch.Tensor, + description="4D position IDs for negative text tokens (T, H, W, L)", + ), + InputParam( + "latent_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for latent tokens (T, H, W, L)", + ), + ] + + @torch.no_grad() + def __call__( + self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + latents = block_state.latents + latent_model_input = latents.to(components.transformer.dtype) + img_ids = block_state.latent_ids + + image_latents = getattr(block_state, "image_latents", None) + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype) + image_latent_ids = block_state.image_latent_ids + img_ids = torch.cat([img_ids, image_latent_ids], dim=1) + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + guider_inputs = { + "encoder_hidden_states": ( + getattr(block_state, "prompt_embeds", None), + getattr(block_state, "negative_prompt_embeds", None), + ), + "txt_ids": ( + getattr(block_state, "txt_ids", None), + getattr(block_state, "negative_txt_ids", None), + ), + } + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_state = components.guider.prepare_inputs(guider_inputs) + + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()} + + noise_pred = components.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=None, + img_ids=img_ids, + joint_attention_kwargs=block_state.joint_attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + guider_state_batch.noise_pred = noise_pred[:, : latents.size(1)] + components.guider.cleanup_models(components.transformer) + + # perform guidance + block_state.noise_pred = components.guider(guider_state)[0] + + return components, block_state + + +class Flux2LoopAfterDenoiser(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that updates the latents after denoising. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `Flux2DenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [] + + @property + def intermediate_inputs(self) -> list[str]: + return [InputParam("generator")] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred, + t, + block_state.latents, + return_dict=False, + )[0] + + if block_state.latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class Flux2DenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoises the latents over `timesteps`. " + "The specific steps within each iteration can be customized with `sub_blocks` attribute" + ) + + @property + def loop_expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ComponentSpec("transformer", Flux2Transformer2DModel), + ] + + @property + def loop_inputs(self) -> list[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process.", + ), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + ) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self.set_block_state(state, block_state) + return components, state + + +class Flux2DenoiseStep(Flux2DenoiseLoopWrapper): + block_classes = [Flux2LoopDenoiser, Flux2LoopAfterDenoiser] + block_names = ["denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoises the latents for Flux2. \n" + "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `Flux2LoopDenoiser`\n" + " - `Flux2LoopAfterDenoiser`\n" + "This block supports both text-to-image and image-conditioned generation." + ) + + +class Flux2KleinDenoiseStep(Flux2DenoiseLoopWrapper): + block_classes = [Flux2KleinLoopDenoiser, Flux2LoopAfterDenoiser] + block_names = ["denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoises the latents for Flux2. \n" + "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `Flux2KleinLoopDenoiser`\n" + " - `Flux2LoopAfterDenoiser`\n" + "This block supports both text-to-image and image-conditioned generation." + ) + + +class Flux2KleinBaseDenoiseStep(Flux2DenoiseLoopWrapper): + block_classes = [Flux2KleinBaseLoopDenoiser, Flux2LoopAfterDenoiser] + block_names = ["denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoises the latents for Flux2. \n" + "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `Flux2KleinBaseLoopDenoiser`\n" + " - `Flux2LoopAfterDenoiser`\n" + "This block supports both text-to-image and image-conditioned generation." + ) diff --git a/diffusers/modular_pipelines/flux2/encoders.py b/diffusers/modular_pipelines/flux2/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..81d20a8f4c6532db2b04d1423e77c1a9f1b0347d --- /dev/null +++ b/diffusers/modular_pipelines/flux2/encoders.py @@ -0,0 +1,608 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Qwen2TokenizerFast, Qwen3ForCausalLM + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import AutoencoderKLFlux2 +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def format_text_input(prompts: list[str], system_message: str = None): + """Format prompts for Mistral3 chat template.""" + cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts] + + return [ + [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + }, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + for prompt in cleaned_txt + ] + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class Flux2TextEncoderStep(ModularPipelineBlocks): + model_name = "flux2" + + # fmt: off + DEFAULT_SYSTEM_MESSAGE = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation." + # fmt: on + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings using Mistral3 to guide the image generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Mistral3ForConditionalGeneration), + ComponentSpec("tokenizer", AutoProcessor), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("prompt"), + InputParam("max_sequence_length", type_hint=int, default=512, required=False), + InputParam("text_encoder_out_layers", type_hint=tuple[int], default=(10, 20, 30), required=False), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Text embeddings from Mistral3 used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + prompt = block_state.prompt + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @staticmethod + def _get_mistral_3_prompt_embeds( + text_encoder: Mistral3ForConditionalGeneration, + tokenizer: AutoProcessor, + prompt: str | list[str], + dtype: torch.dtype | None = None, + device: torch.device | None = None, + max_sequence_length: int = 512, + # fmt: off + system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.", + # fmt: on + hidden_states_layers: tuple[int] = (10, 20, 30), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + messages_batch = format_text_input(prompts=prompt, system_message=system_message) + + inputs = tokenizer.apply_chat_template( + messages_batch, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.device = components._execution_device + + prompt = block_state.prompt + if prompt is None: + prompt = "" + prompt = [prompt] if isinstance(prompt, str) else prompt + + block_state.prompt_embeds = self._get_mistral_3_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=prompt, + device=block_state.device, + max_sequence_length=block_state.max_sequence_length, + system_message=self.DEFAULT_SYSTEM_MESSAGE, + hidden_states_layers=block_state.text_encoder_out_layers, + ) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2RemoteTextEncoderStep(ModularPipelineBlocks): + model_name = "flux2" + + REMOTE_URL = "https://remote-text-encoder-flux-2.huggingface.co/predict" + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings using a remote API endpoint" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("prompt"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Text embeddings from remote API used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + prompt = block_state.prompt + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + import io + + import requests + from huggingface_hub import get_token + + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.device = components._execution_device + + prompt = block_state.prompt + if prompt is None: + prompt = "" + prompt = [prompt] if isinstance(prompt, str) else prompt + + response = requests.post( + self.REMOTE_URL, + json={"prompt": prompt}, + headers={ + "Authorization": f"Bearer {get_token()}", + "Content-Type": "application/json", + }, + ) + response.raise_for_status() + + block_state.prompt_embeds = torch.load(io.BytesIO(response.content), weights_only=True) + block_state.prompt_embeds = block_state.prompt_embeds.to(block_state.device) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2KleinTextEncoderStep(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen3ForCausalLM), + ComponentSpec("tokenizer", Qwen2TokenizerFast), + ] + + @property + def expected_configs(self) -> list[ConfigSpec]: + return [ + ConfigSpec(name="is_distilled", default=True), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("prompt"), + InputParam("max_sequence_length", type_hint=int, default=512, required=False), + InputParam("text_encoder_out_layers", type_hint=tuple[int], default=(9, 18, 27), required=False), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Text embeddings from qwen3 used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + prompt = block_state.prompt + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds + def _get_qwen3_prompt_embeds( + text_encoder: Qwen3ForCausalLM, + tokenizer: Qwen2TokenizerFast, + prompt: str | list[str], + dtype: torch.dtype | None = None, + device: torch.device | None = None, + max_sequence_length: int = 512, + hidden_states_layers: list[int] = (9, 18, 27), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + all_input_ids = [] + all_attention_masks = [] + + for single_prompt in prompt: + messages = [{"role": "user", "content": single_prompt}] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + all_input_ids.append(inputs["input_ids"]) + all_attention_masks.append(inputs["attention_mask"]) + + input_ids = torch.cat(all_input_ids, dim=0).to(device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @torch.no_grad() + def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + device = components._execution_device + + prompt = block_state.prompt + if prompt is None: + prompt = "" + prompt = [prompt] if isinstance(prompt, str) else prompt + + block_state.prompt_embeds = self._get_qwen3_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=block_state.max_sequence_length, + hidden_states_layers=block_state.text_encoder_out_layers, + ) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2KleinBaseTextEncoderStep(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen3ForCausalLM), + ComponentSpec("tokenizer", Qwen2TokenizerFast), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ] + + @property + def expected_configs(self) -> list[ConfigSpec]: + return [ + ConfigSpec(name="is_distilled", default=False), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("prompt"), + InputParam("max_sequence_length", type_hint=int, default=512, required=False), + InputParam("text_encoder_out_layers", type_hint=tuple[int], default=(9, 18, 27), required=False), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Text embeddings from qwen3 used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Negative text embeddings from qwen3 used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + prompt = block_state.prompt + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds + def _get_qwen3_prompt_embeds( + text_encoder: Qwen3ForCausalLM, + tokenizer: Qwen2TokenizerFast, + prompt: str | list[str], + dtype: torch.dtype | None = None, + device: torch.device | None = None, + max_sequence_length: int = 512, + hidden_states_layers: list[int] = (9, 18, 27), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + all_input_ids = [] + all_attention_masks = [] + + for single_prompt in prompt: + messages = [{"role": "user", "content": single_prompt}] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + all_input_ids.append(inputs["input_ids"]) + all_attention_masks.append(inputs["attention_mask"]) + + input_ids = torch.cat(all_input_ids, dim=0).to(device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @torch.no_grad() + def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + device = components._execution_device + + prompt = block_state.prompt + if prompt is None: + prompt = "" + prompt = [prompt] if isinstance(prompt, str) else prompt + + block_state.prompt_embeds = self._get_qwen3_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=block_state.max_sequence_length, + hidden_states_layers=block_state.text_encoder_out_layers, + ) + + if components.requires_unconditional_embeds: + negative_prompt = [""] * len(prompt) + block_state.negative_prompt_embeds = self._get_qwen3_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=negative_prompt, + device=device, + max_sequence_length=block_state.max_sequence_length, + hidden_states_layers=block_state.text_encoder_out_layers, + ) + else: + block_state.negative_prompt_embeds = None + + self.set_block_state(state, block_state) + return components, state + + +class Flux2VaeEncoderStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "VAE Encoder step that encodes preprocessed images into latent representations for Flux2." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("vae", AutoencoderKLFlux2)] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("condition_images", type_hint=list[torch.Tensor]), + InputParam("generator"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "image_latents", + type_hint=list[torch.Tensor], + description="List of latent representations for each reference image", + ), + ] + + @staticmethod + def _patchify_latents(latents): + """Convert latents to patchified format for Flux2.""" + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2) + return latents + + def _encode_vae_image(self, vae: AutoencoderKLFlux2, image: torch.Tensor, generator: torch.Generator): + """Encode a single image using Flux2 VAE with batch norm normalization.""" + if image.ndim != 4: + raise ValueError(f"Expected image dims 4, got {image.ndim}.") + + image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = self._patchify_latents(image_latents) + + latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype) + latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps) + latents_bn_std = latents_bn_std.to(image_latents.device, image_latents.dtype) + image_latents = (image_latents - latents_bn_mean) / latents_bn_std + + return image_latents + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + condition_images = block_state.condition_images + + if condition_images is None: + return components, state + + device = components._execution_device + dtype = components.vae.dtype + + image_latents = [] + for image in condition_images: + image = image.to(device=device, dtype=dtype) + latent = self._encode_vae_image( + vae=components.vae, + image=image, + generator=block_state.generator, + ) + image_latents.append(latent) + + block_state.image_latents = image_latents + + self.set_block_state(state, block_state) + return components, state diff --git a/diffusers/modular_pipelines/flux2/inputs.py b/diffusers/modular_pipelines/flux2/inputs.py new file mode 100644 index 0000000000000000000000000000000000000000..7c6881f70a46c6948d04e9c9d980ee194da683c5 --- /dev/null +++ b/diffusers/modular_pipelines/flux2/inputs.py @@ -0,0 +1,242 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from ...configuration_utils import FrozenDict +from ...pipelines.flux2.image_processor import Flux2ImageProcessor +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import Flux2ModularPipeline + + +logger = logging.get_logger(__name__) + + +class Flux2TextInputStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return ( + "This step:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + InputParam( + "prompt_embeds", + required=True, + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `prompt_embeds`)", + ), + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Text embeddings used to guide the image generation", + ), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2KleinBaseTextInputStep(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def description(self) -> str: + return ( + "This step:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + InputParam( + "prompt_embeds", + required=True, + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "negative_prompt_embeds", + required=False, + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Pre-generated negative text embeddings. Can be generated from text_encoder step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `prompt_embeds`)", + ), + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Text embeddings used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Negative text embeddings used to guide the image generation", + ), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2ProcessImagesInputStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "Image preprocess step for Flux2. Validates and preprocesses reference images." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + Flux2ImageProcessor, + config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("image"), + InputParam("height"), + InputParam("width"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam(name="condition_images", type_hint=list[torch.Tensor])] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + images = block_state.image + + if images is None: + block_state.condition_images = None + self.set_block_state(state, block_state) + return components, state + + if not isinstance(images, list): + images = [images] + + condition_images = [] + for img in images: + components.image_processor.check_image_input(img) + + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + img = components.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + + multiple_of = components.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + condition_img = components.image_processor.preprocess( + img, height=image_height, width=image_width, resize_mode="crop" + ) + condition_images.append(condition_img) + + if block_state.height is None: + block_state.height = image_height + if block_state.width is None: + block_state.width = image_width + + block_state.condition_images = condition_images + + self.set_block_state(state, block_state) + return components, state diff --git a/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py b/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py new file mode 100644 index 0000000000000000000000000000000000000000..b1033a7dff9e868dabbd078822c814be0d9fc546 --- /dev/null +++ b/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py @@ -0,0 +1,356 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict, OutputParam +from .before_denoise import ( + Flux2PrepareGuidanceStep, + Flux2PrepareImageLatentsStep, + Flux2PrepareLatentsStep, + Flux2RoPEInputsStep, + Flux2SetTimestepsStep, +) +from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep +from .denoise import Flux2DenoiseStep +from .encoders import ( + Flux2TextEncoderStep, + Flux2VaeEncoderStep, +) +from .inputs import ( + Flux2ProcessImagesInputStep, + Flux2TextInputStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# auto_docstring +class Flux2VaeEncoderSequentialStep(SequentialPipelineBlocks): + """ + VAE encoder step that preprocesses, encodes, and prepares image latents for Flux2 conditioning. + + Components: + image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`) + + Inputs: + image (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + condition_images (`list`): + TODO: Add description. + image_latents (`list`): + List of latent representations for each reference image + """ + + model_name = "flux2" + + block_classes = [Flux2ProcessImagesInputStep(), Flux2VaeEncoderStep()] + block_names = ["preprocess", "encode"] + + @property + def description(self) -> str: + return "VAE encoder step that preprocesses, encodes, and prepares image latents for Flux2 conditioning." + + +# auto_docstring +class Flux2AutoVaeEncoderStep(AutoPipelineBlocks): + """ + VAE encoder step that encodes the image inputs into their latent representations. + This is an auto pipeline block that works for image conditioning tasks. + - `Flux2VaeEncoderSequentialStep` is used when `image` is provided. + - If `image` is not provided, step will be skipped. + + Components: + image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`) + + Inputs: + image (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + condition_images (`list`): + TODO: Add description. + image_latents (`list`): + List of latent representations for each reference image + """ + + block_classes = [Flux2VaeEncoderSequentialStep] + block_names = ["img_conditioning"] + block_trigger_inputs = ["image"] + + @property + def description(self): + return ( + "VAE encoder step that encodes the image inputs into their latent representations.\n" + "This is an auto pipeline block that works for image conditioning tasks.\n" + " - `Flux2VaeEncoderSequentialStep` is used when `image` is provided.\n" + " - If `image` is not provided, step will be skipped." + ) + + +Flux2CoreDenoiseBlocks = InsertableDict( + [ + ("input", Flux2TextInputStep()), + ("prepare_latents", Flux2PrepareLatentsStep()), + ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_guidance", Flux2PrepareGuidanceStep()), + ("prepare_rope_inputs", Flux2RoPEInputsStep()), + ("denoise", Flux2DenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), + ] +) + + +# auto_docstring +class Flux2CoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoise step that performs the denoising process for Flux2-dev. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 4.0): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + image_latents (`Tensor`, *optional*): + Packed image latents for conditioning. Shape: (B, img_seq_len, C) + image_latent_ids (`Tensor`, *optional*): + Position IDs for image latents. Shape: (B, img_seq_len, 4) + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "flux2" + + block_classes = Flux2CoreDenoiseBlocks.values() + block_names = Flux2CoreDenoiseBlocks.keys() + + @property + def description(self): + return "Core denoise step that performs the denoising process for Flux2-dev." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +Flux2ImageConditionedCoreDenoiseBlocks = InsertableDict( + [ + ("input", Flux2TextInputStep()), + ("prepare_image_latents", Flux2PrepareImageLatentsStep()), + ("prepare_latents", Flux2PrepareLatentsStep()), + ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_guidance", Flux2PrepareGuidanceStep()), + ("prepare_rope_inputs", Flux2RoPEInputsStep()), + ("denoise", Flux2DenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), + ] +) + + +# auto_docstring +class Flux2ImageConditionedCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoise step that performs the denoising process for Flux2-dev with image conditioning. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + image_latents (`list`, *optional*): + TODO: Add description. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 4.0): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "flux2" + + block_classes = Flux2ImageConditionedCoreDenoiseBlocks.values() + block_names = Flux2ImageConditionedCoreDenoiseBlocks.keys() + + @property + def description(self): + return "Core denoise step that performs the denoising process for Flux2-dev with image conditioning." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +class Flux2AutoCoreDenoiseStep(AutoPipelineBlocks): + model_name = "flux2" + + block_classes = [Flux2ImageConditionedCoreDenoiseStep, Flux2CoreDenoiseStep] + block_names = ["image_conditioned", "text2image"] + block_trigger_inputs = ["image_latents", None] + + @property + def description(self): + return ( + "Auto core denoise step that performs the denoising process for Flux2-dev." + "This is an auto pipeline block that works for text-to-image and image-conditioned generation." + " - `Flux2CoreDenoiseStep` is used for text-to-image generation.\n" + " - `Flux2ImageConditionedCoreDenoiseStep` is used for image-conditioned generation.\n" + ) + + +AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", Flux2TextEncoderStep()), + ("vae_encoder", Flux2AutoVaeEncoderStep()), + ("denoise", Flux2AutoCoreDenoiseStep()), + ("decode", Flux2DecodeStep()), + ] +) + + +# auto_docstring +class Flux2AutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for text-to-image and image-conditioned generation using Flux2. + + Supported workflows: + - `text2image`: requires `prompt` + - `image_conditioned`: requires `image`, `prompt` + + Components: + text_encoder (`Mistral3ForConditionalGeneration`) tokenizer (`AutoProcessor`) image_processor + (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`) scheduler (`FlowMatchEulerDiscreteScheduler`) transformer + (`Flux2Transformer2DModel`) + + Inputs: + prompt (`None`, *optional*): + TODO: Add description. + max_sequence_length (`int`, *optional*, defaults to 512): + TODO: Add description. + text_encoder_out_layers (`tuple`, *optional*, defaults to (10, 20, 30)): + TODO: Add description. + image (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + image_latents (`list`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`): + TODO: Add description. + num_inference_steps (`None`): + TODO: Add description. + timesteps (`None`): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 4.0): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + image_latent_ids (`Tensor`, *optional*): + Position IDs for image latents. Shape: (B, img_seq_len, 4) + output_type (`None`, *optional*, defaults to pil): + TODO: Add description. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "flux2" + + block_classes = AUTO_BLOCKS.values() + block_names = AUTO_BLOCKS.keys() + _workflow_map = { + "text2image": {"prompt": True}, + "image_conditioned": {"image": True, "prompt": True}, + } + + @property + def description(self): + return "Auto Modular pipeline for text-to-image and image-conditioned generation using Flux2." + + @property + def outputs(self): + return [ + OutputParam.template("images"), + ] diff --git a/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py b/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py new file mode 100644 index 0000000000000000000000000000000000000000..5dbae43a5a7fd6c4df8ebb6314ed8c20d7c3372f --- /dev/null +++ b/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py @@ -0,0 +1,400 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict, OutputParam +from .before_denoise import ( + Flux2PrepareImageLatentsStep, + Flux2PrepareLatentsStep, + Flux2RoPEInputsStep, + Flux2SetTimestepsStep, +) +from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep +from .denoise import Flux2KleinDenoiseStep +from .encoders import ( + Flux2KleinTextEncoderStep, + Flux2VaeEncoderStep, +) +from .inputs import ( + Flux2ProcessImagesInputStep, + Flux2TextInputStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +################ +# VAE encoder +################ + + +# auto_docstring +class Flux2KleinVaeEncoderSequentialStep(SequentialPipelineBlocks): + """ + VAE encoder step that preprocesses and encodes the image inputs into their latent representations. + + Components: + image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`) + + Inputs: + image (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + condition_images (`list`): + TODO: Add description. + image_latents (`list`): + List of latent representations for each reference image + """ + + model_name = "flux2-klein" + + block_classes = [Flux2ProcessImagesInputStep(), Flux2VaeEncoderStep()] + block_names = ["preprocess", "encode"] + + @property + def description(self) -> str: + return "VAE encoder step that preprocesses and encodes the image inputs into their latent representations." + + +# auto_docstring +class Flux2KleinAutoVaeEncoderStep(AutoPipelineBlocks): + """ + VAE encoder step that encodes the image inputs into their latent representations. + This is an auto pipeline block that works for image conditioning tasks. + - `Flux2KleinVaeEncoderSequentialStep` is used when `image` is provided. + - If `image` is not provided, step will be skipped. + + Components: + image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`) + + Inputs: + image (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + condition_images (`list`): + TODO: Add description. + image_latents (`list`): + List of latent representations for each reference image + """ + + model_name = "flux2-klein" + + block_classes = [Flux2KleinVaeEncoderSequentialStep] + block_names = ["img_conditioning"] + block_trigger_inputs = ["image"] + + @property + def description(self): + return ( + "VAE encoder step that encodes the image inputs into their latent representations.\n" + "This is an auto pipeline block that works for image conditioning tasks.\n" + " - `Flux2KleinVaeEncoderSequentialStep` is used when `image` is provided.\n" + " - If `image` is not provided, step will be skipped." + ) + + +### +### Core denoise +### + +Flux2KleinCoreDenoiseBlocks = InsertableDict( + [ + ("input", Flux2TextInputStep()), + ("prepare_latents", Flux2PrepareLatentsStep()), + ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_rope_inputs", Flux2RoPEInputsStep()), + ("denoise", Flux2KleinDenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), + ] +) + + +# auto_docstring +class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoise step that performs the denoising process for Flux2-Klein (distilled model), for text-to-image + generation. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + image_latents (`Tensor`, *optional*): + Packed image latents for conditioning. Shape: (B, img_seq_len, C) + image_latent_ids (`Tensor`, *optional*): + Position IDs for image latents. Shape: (B, img_seq_len, 4) + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "flux2-klein" + + block_classes = Flux2KleinCoreDenoiseBlocks.values() + block_names = Flux2KleinCoreDenoiseBlocks.keys() + + @property + def description(self): + return "Core denoise step that performs the denoising process for Flux2-Klein (distilled model), for text-to-image generation." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +Flux2KleinImageConditionedCoreDenoiseBlocks = InsertableDict( + [ + ("input", Flux2TextInputStep()), + ("prepare_image_latents", Flux2PrepareImageLatentsStep()), + ("prepare_latents", Flux2PrepareLatentsStep()), + ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_rope_inputs", Flux2RoPEInputsStep()), + ("denoise", Flux2KleinDenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), + ] +) + + +# auto_docstring +class Flux2KleinImageConditionedCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoise step that performs the denoising process for Flux2-Klein (distilled model) with image conditioning. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + image_latents (`list`, *optional*): + TODO: Add description. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "flux2-klein" + + block_classes = Flux2KleinImageConditionedCoreDenoiseBlocks.values() + block_names = Flux2KleinImageConditionedCoreDenoiseBlocks.keys() + + @property + def description(self): + return "Core denoise step that performs the denoising process for Flux2-Klein (distilled model) with image conditioning." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# auto_docstring +class Flux2KleinAutoCoreDenoiseStep(AutoPipelineBlocks): + """ + Auto core denoise step that performs the denoising process for Flux2-Klein. + This is an auto pipeline block that works for text-to-image and image-conditioned generation. + - `Flux2KleinCoreDenoiseStep` is used for text-to-image generation. + - `Flux2KleinImageConditionedCoreDenoiseStep` is used for image-conditioned generation. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + image_latents (`list`, *optional*): + TODO: Add description. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`): + TODO: Add description. + timesteps (`None`): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + image_latent_ids (`Tensor`, *optional*): + Position IDs for image latents. Shape: (B, img_seq_len, 4) + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "flux2-klein" + block_classes = [Flux2KleinImageConditionedCoreDenoiseStep, Flux2KleinCoreDenoiseStep] + block_names = ["image_conditioned", "text2image"] + block_trigger_inputs = ["image_latents", None] + + @property + def description(self): + return ( + "Auto core denoise step that performs the denoising process for Flux2-Klein.\n" + "This is an auto pipeline block that works for text-to-image and image-conditioned generation.\n" + " - `Flux2KleinCoreDenoiseStep` is used for text-to-image generation.\n" + " - `Flux2KleinImageConditionedCoreDenoiseStep` is used for image-conditioned generation.\n" + ) + + +### +### Auto blocks +### + + +# auto_docstring +class Flux2KleinAutoBlocks(SequentialPipelineBlocks): + """ + Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein. + + Supported workflows: + - `text2image`: requires `prompt` + - `image_conditioned`: requires `image`, `prompt` + + Components: + text_encoder (`Qwen3ForCausalLM`) tokenizer (`Qwen2TokenizerFast`) image_processor (`Flux2ImageProcessor`) + vae (`AutoencoderKLFlux2`) scheduler (`FlowMatchEulerDiscreteScheduler`) transformer + (`Flux2Transformer2DModel`) + + Configs: + is_distilled (default: True) + + Inputs: + prompt (`None`, *optional*): + TODO: Add description. + max_sequence_length (`int`, *optional*, defaults to 512): + TODO: Add description. + text_encoder_out_layers (`tuple`, *optional*, defaults to (9, 18, 27)): + TODO: Add description. + image (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + image_latents (`list`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`): + TODO: Add description. + num_inference_steps (`None`): + TODO: Add description. + timesteps (`None`): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + image_latent_ids (`Tensor`, *optional*): + Position IDs for image latents. Shape: (B, img_seq_len, 4) + output_type (`None`, *optional*, defaults to pil): + TODO: Add description. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "flux2-klein" + block_classes = [ + Flux2KleinTextEncoderStep(), + Flux2KleinAutoVaeEncoderStep(), + Flux2KleinAutoCoreDenoiseStep(), + Flux2DecodeStep(), + ] + block_names = ["text_encoder", "vae_encoder", "denoise", "decode"] + _workflow_map = { + "text2image": {"prompt": True}, + "image_conditioned": {"image": True, "prompt": True}, + } + + @property + def description(self): + return "Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein." + + @property + def outputs(self): + return [ + OutputParam.template("images"), + ] diff --git a/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein_base.py b/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein_base.py new file mode 100644 index 0000000000000000000000000000000000000000..42e025c622b4e04602f5ab7dd13de46bdf555623 --- /dev/null +++ b/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein_base.py @@ -0,0 +1,413 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict, OutputParam +from .before_denoise import ( + Flux2KleinBaseRoPEInputsStep, + Flux2PrepareImageLatentsStep, + Flux2PrepareLatentsStep, + Flux2SetTimestepsStep, +) +from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep +from .denoise import Flux2KleinBaseDenoiseStep +from .encoders import ( + Flux2KleinBaseTextEncoderStep, + Flux2VaeEncoderStep, +) +from .inputs import ( + Flux2KleinBaseTextInputStep, + Flux2ProcessImagesInputStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +################ +# VAE encoder +################ + + +# auto_docstring +class Flux2KleinBaseVaeEncoderSequentialStep(SequentialPipelineBlocks): + """ + VAE encoder step that preprocesses and encodes the image inputs into their latent representations. + + Components: + image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`) + + Inputs: + image (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + condition_images (`list`): + TODO: Add description. + image_latents (`list`): + List of latent representations for each reference image + """ + + model_name = "flux2" + + block_classes = [Flux2ProcessImagesInputStep(), Flux2VaeEncoderStep()] + block_names = ["preprocess", "encode"] + + @property + def description(self) -> str: + return "VAE encoder step that preprocesses and encodes the image inputs into their latent representations." + + +# auto_docstring +class Flux2KleinBaseAutoVaeEncoderStep(AutoPipelineBlocks): + """ + VAE encoder step that encodes the image inputs into their latent representations. + This is an auto pipeline block that works for image conditioning tasks. + - `Flux2KleinBaseVaeEncoderSequentialStep` is used when `image` is provided. + - If `image` is not provided, step will be skipped. + + Components: + image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`) + + Inputs: + image (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + condition_images (`list`): + TODO: Add description. + image_latents (`list`): + List of latent representations for each reference image + """ + + block_classes = [Flux2KleinBaseVaeEncoderSequentialStep] + block_names = ["img_conditioning"] + block_trigger_inputs = ["image"] + + @property + def description(self): + return ( + "VAE encoder step that encodes the image inputs into their latent representations.\n" + "This is an auto pipeline block that works for image conditioning tasks.\n" + " - `Flux2KleinBaseVaeEncoderSequentialStep` is used when `image` is provided.\n" + " - If `image` is not provided, step will be skipped." + ) + + +### +### Core denoise +### + +Flux2KleinBaseCoreDenoiseBlocks = InsertableDict( + [ + ("input", Flux2KleinBaseTextInputStep()), + ("prepare_latents", Flux2PrepareLatentsStep()), + ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_rope_inputs", Flux2KleinBaseRoPEInputsStep()), + ("denoise", Flux2KleinBaseDenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), + ] +) + + +# auto_docstring +class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoise step that performs the denoising process for Flux2-Klein (base model), for text-to-image generation. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`) guider + (`ClassifierFreeGuidance`) + + Configs: + is_distilled (default: False) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + image_latents (`Tensor`, *optional*): + Packed image latents for conditioning. Shape: (B, img_seq_len, C) + image_latent_ids (`Tensor`, *optional*): + Position IDs for image latents. Shape: (B, img_seq_len, 4) + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "flux2-klein" + block_classes = Flux2KleinBaseCoreDenoiseBlocks.values() + block_names = Flux2KleinBaseCoreDenoiseBlocks.keys() + + @property + def description(self): + return "Core denoise step that performs the denoising process for Flux2-Klein (base model), for text-to-image generation." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +Flux2KleinBaseImageConditionedCoreDenoiseBlocks = InsertableDict( + [ + ("input", Flux2KleinBaseTextInputStep()), + ("prepare_latents", Flux2PrepareLatentsStep()), + ("prepare_image_latents", Flux2PrepareImageLatentsStep()), + ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_rope_inputs", Flux2KleinBaseRoPEInputsStep()), + ("denoise", Flux2KleinBaseDenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), + ] +) + + +# auto_docstring +class Flux2KleinBaseImageConditionedCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoise step that performs the denoising process for Flux2-Klein (base model) with image conditioning. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`) guider + (`ClassifierFreeGuidance`) + + Configs: + is_distilled (default: False) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + image_latents (`list`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "flux2-klein" + block_classes = Flux2KleinBaseImageConditionedCoreDenoiseBlocks.values() + block_names = Flux2KleinBaseImageConditionedCoreDenoiseBlocks.keys() + + @property + def description(self): + return "Core denoise step that performs the denoising process for Flux2-Klein (base model) with image conditioning." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# auto_docstring +class Flux2KleinBaseAutoCoreDenoiseStep(AutoPipelineBlocks): + """ + Auto core denoise step that performs the denoising process for Flux2-Klein (base model). + This is an auto pipeline block that works for text-to-image and image-conditioned generation. + - `Flux2KleinBaseCoreDenoiseStep` is used for text-to-image generation. + - `Flux2KleinBaseImageConditionedCoreDenoiseStep` is used for image-conditioned generation. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`) guider + (`ClassifierFreeGuidance`) + + Configs: + is_distilled (default: False) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + image_latents (`list`, *optional*): + TODO: Add description. + num_inference_steps (`None`): + TODO: Add description. + timesteps (`None`): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + image_latent_ids (`Tensor`, *optional*): + Position IDs for image latents. Shape: (B, img_seq_len, 4) + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "flux2-klein" + block_classes = [Flux2KleinBaseImageConditionedCoreDenoiseStep, Flux2KleinBaseCoreDenoiseStep] + block_names = ["image_conditioned", "text2image"] + block_trigger_inputs = ["image_latents", None] + + @property + def description(self): + return ( + "Auto core denoise step that performs the denoising process for Flux2-Klein (base model).\n" + "This is an auto pipeline block that works for text-to-image and image-conditioned generation.\n" + " - `Flux2KleinBaseCoreDenoiseStep` is used for text-to-image generation.\n" + " - `Flux2KleinBaseImageConditionedCoreDenoiseStep` is used for image-conditioned generation.\n" + ) + + +### +### Auto blocks +### + + +# auto_docstring +class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks): + """ + Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein (base model). + + Supported workflows: + - `text2image`: requires `prompt` + - `image_conditioned`: requires `image`, `prompt` + + Components: + text_encoder (`Qwen3ForCausalLM`) tokenizer (`Qwen2TokenizerFast`) guider (`ClassifierFreeGuidance`) + image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`) scheduler + (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`) + + Configs: + is_distilled (default: False) + + Inputs: + prompt (`None`, *optional*): + TODO: Add description. + max_sequence_length (`int`, *optional*, defaults to 512): + TODO: Add description. + text_encoder_out_layers (`tuple`, *optional*, defaults to (9, 18, 27)): + TODO: Add description. + image (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + latents (`Tensor | NoneType`): + TODO: Add description. + image_latents (`list`, *optional*): + TODO: Add description. + num_inference_steps (`None`): + TODO: Add description. + timesteps (`None`): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + image_latent_ids (`Tensor`, *optional*): + Position IDs for image latents. Shape: (B, img_seq_len, 4) + output_type (`None`, *optional*, defaults to pil): + TODO: Add description. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "flux2-klein" + block_classes = [ + Flux2KleinBaseTextEncoderStep(), + Flux2KleinBaseAutoVaeEncoderStep(), + Flux2KleinBaseAutoCoreDenoiseStep(), + Flux2DecodeStep(), + ] + block_names = ["text_encoder", "vae_encoder", "denoise", "decode"] + _workflow_map = { + "text2image": {"prompt": True}, + "image_conditioned": {"image": True, "prompt": True}, + } + + @property + def description(self): + return "Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein (base model)." + + @property + def outputs(self): + return [ + OutputParam.template("images"), + ] diff --git a/diffusers/modular_pipelines/flux2/modular_pipeline.py b/diffusers/modular_pipelines/flux2/modular_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..31ba5aec7cfbeedf1d842e918650da0243719e25 --- /dev/null +++ b/diffusers/modular_pipelines/flux2/modular_pipeline.py @@ -0,0 +1,99 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...loaders import Flux2LoraLoaderMixin +from ...utils import logging +from ..modular_pipeline import ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin): + """ + A ModularPipeline for Flux2. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "Flux2AutoBlocks" + + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + return 128 + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if getattr(self, "vae", None) is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_latents(self): + num_channels_latents = 32 + if getattr(self, "transformer", None): + num_channels_latents = self.transformer.config.in_channels // 4 + return num_channels_latents + + +class Flux2KleinModularPipeline(Flux2ModularPipeline): + """ + A ModularPipeline for Flux2-Klein (distilled model). + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "Flux2KleinAutoBlocks" + + @property + def requires_unconditional_embeds(self): + if hasattr(self.config, "is_distilled") and self.config.is_distilled: + return False + + requires_unconditional_embeds = False + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 + + return requires_unconditional_embeds + + +class Flux2KleinBaseModularPipeline(Flux2ModularPipeline): + """ + A ModularPipeline for Flux2-Klein (base model). + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "Flux2KleinBaseAutoBlocks" + + @property + def requires_unconditional_embeds(self): + if hasattr(self.config, "is_distilled") and self.config.is_distilled: + return False + + requires_unconditional_embeds = False + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 + + return requires_unconditional_embeds diff --git a/diffusers/modular_pipelines/qwenimage/__init__.py b/diffusers/modular_pipelines/qwenimage/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2e6af4495b376219a6bb6bcfaeb3b70bace9d89d --- /dev/null +++ b/diffusers/modular_pipelines/qwenimage/__init__.py @@ -0,0 +1,63 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modular_blocks_qwenimage"] = ["QwenImageAutoBlocks"] + _import_structure["modular_blocks_qwenimage_edit"] = ["QwenImageEditAutoBlocks"] + _import_structure["modular_blocks_qwenimage_edit_plus"] = ["QwenImageEditPlusAutoBlocks"] + _import_structure["modular_blocks_qwenimage_layered"] = ["QwenImageLayeredAutoBlocks"] + _import_structure["modular_pipeline"] = [ + "QwenImageEditModularPipeline", + "QwenImageEditPlusModularPipeline", + "QwenImageLayeredModularPipeline", + "QwenImageModularPipeline", + ] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_blocks_qwenimage import QwenImageAutoBlocks + from .modular_blocks_qwenimage_edit import QwenImageEditAutoBlocks + from .modular_blocks_qwenimage_edit_plus import QwenImageEditPlusAutoBlocks + from .modular_blocks_qwenimage_layered import QwenImageLayeredAutoBlocks + from .modular_pipeline import ( + QwenImageEditModularPipeline, + QwenImageEditPlusModularPipeline, + QwenImageLayeredModularPipeline, + QwenImageModularPipeline, + ) +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/modular_pipelines/qwenimage/before_denoise.py b/diffusers/modular_pipelines/qwenimage/before_denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..51b5c6ac8c3db39fa8b1cf8c333a803fb9bba2a2 --- /dev/null +++ b/diffusers/modular_pipelines/qwenimage/before_denoise.py @@ -0,0 +1,1330 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import numpy as np +import torch + +from ...models import QwenImageControlNetModel, QwenImageMultiControlNetModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils.torch_utils import randn_tensor, unwrap_module +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import QwenImageLayeredPachifier, QwenImageModularPipeline, QwenImagePachifier + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# modified from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps +def get_timesteps(scheduler, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = scheduler.timesteps[t_start * scheduler.order :] + if hasattr(scheduler, "set_begin_index"): + scheduler.set_begin_index(t_start * scheduler.order) + + return timesteps, num_inference_steps - t_start + + +# ==================== +# 1. PREPARE LATENTS +# ==================== + + +# auto_docstring +class QwenImagePrepareLatentsStep(ModularPipelineBlocks): + """ + Prepare initial random noise for the generation process + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. + + Outputs: + height (`int`): + if not set, updated to default value + width (`int`): + if not set, updated to default value + latents (`Tensor`): + The initial latents to use for the denoising process + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Prepare initial random noise for the generation process" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("latents"), + InputParam.template("height"), + InputParam.template("width"), + InputParam.template("num_images_per_prompt"), + InputParam.template("generator"), + InputParam.template("batch_size"), + InputParam.template("dtype"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam(name="height", type_hint=int, description="if not set, updated to default value"), + OutputParam(name="width", type_hint=int, description="if not set, updated to default value"), + OutputParam( + name="latents", + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process", + ), + ] + + @staticmethod + def check_inputs(height, width, vae_scale_factor): + if height is not None and height % (vae_scale_factor * 2) != 0: + raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}") + + if width is not None and width % (vae_scale_factor * 2) != 0: + raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + self.check_inputs( + height=block_state.height, + width=block_state.width, + vae_scale_factor=components.vae_scale_factor, + ) + + device = components._execution_device + batch_size = block_state.batch_size * block_state.num_images_per_prompt + + # we can update the height and width here since it's used to generate the initial + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + latent_height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2)) + latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2)) + + shape = (batch_size, components.num_channels_latents, 1, latent_height, latent_width) + if isinstance(block_state.generator, list) and len(block_state.generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if block_state.latents is None: + block_state.latents = randn_tensor( + shape, generator=block_state.generator, device=device, dtype=block_state.dtype + ) + block_state.latents = components.pachifier.pack_latents(block_state.latents) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks): + """ + Prepare initial random noise (B, layers+1, C, H, W) for the generation process + + Components: + pachifier (`QwenImageLayeredPachifier`) + + Inputs: + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. + + Outputs: + height (`int`): + if not set, updated to default value + width (`int`): + if not set, updated to default value + latents (`Tensor`): + The initial latents to use for the denoising process + """ + + model_name = "qwenimage-layered" + + @property + def description(self) -> str: + return "Prepare initial random noise (B, layers+1, C, H, W) for the generation process" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("pachifier", QwenImageLayeredPachifier, default_creation_method="from_config"), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("latents"), + InputParam.template("height"), + InputParam.template("width"), + InputParam.template("layers"), + InputParam.template("num_images_per_prompt"), + InputParam.template("generator"), + InputParam.template("batch_size"), + InputParam.template("dtype"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam(name="height", type_hint=int, description="if not set, updated to default value"), + OutputParam(name="width", type_hint=int, description="if not set, updated to default value"), + OutputParam( + name="latents", + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process", + ), + ] + + @staticmethod + def check_inputs(height, width, vae_scale_factor): + if height is not None and height % (vae_scale_factor * 2) != 0: + raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}") + + if width is not None and width % (vae_scale_factor * 2) != 0: + raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + self.check_inputs( + height=block_state.height, + width=block_state.width, + vae_scale_factor=components.vae_scale_factor, + ) + + device = components._execution_device + batch_size = block_state.batch_size * block_state.num_images_per_prompt + + # we can update the height and width here since it's used to generate the initial + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + latent_height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2)) + latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2)) + + shape = (batch_size, block_state.layers + 1, components.num_channels_latents, latent_height, latent_width) + if isinstance(block_state.generator, list) and len(block_state.generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if block_state.latents is None: + block_state.latents = randn_tensor( + shape, generator=block_state.generator, device=device, dtype=block_state.dtype + ) + block_state.latents = components.pachifier.pack_latents(block_state.latents) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks): + """ + Step that adds noise to image latents for image-to-image/inpainting. Should be run after set_timesteps, + prepare_latents. Both noise and image latents should alreadybe patchified. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + latents (`Tensor`): + The initial random noised, can be generated in prepare latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (Can be + generated from vae encoder and updated in input step.) + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + + Outputs: + initial_noise (`Tensor`): + The initial random noised used for inpainting denoising. + latents (`Tensor`): + The scaled noisy latents to use for inpainting/image-to-image denoising. + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Step that adds noise to image latents for image-to-image/inpainting. Should be run after set_timesteps, prepare_latents. Both noise and image latents should alreadybe patchified." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The initial random noised, can be generated in prepare latent step.", + ), + InputParam.template("image_latents", note="Can be generated from vae encoder and updated in input step."), + InputParam( + name="timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="initial_noise", + type_hint=torch.Tensor, + description="The initial random noised used for inpainting denoising.", + ), + OutputParam( + name="latents", + type_hint=torch.Tensor, + description="The scaled noisy latents to use for inpainting/image-to-image denoising.", + ), + ] + + @staticmethod + def check_inputs(image_latents, latents): + if image_latents.shape[0] != latents.shape[0]: + raise ValueError( + f"`image_latents` must have have same batch size as `latents`, but got {image_latents.shape[0]} and {latents.shape[0]}" + ) + + if image_latents.ndim != 3: + raise ValueError(f"`image_latents` must have 3 dimensions (patchified), but got {image_latents.ndim}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + self.check_inputs( + image_latents=block_state.image_latents, + latents=block_state.latents, + ) + + # prepare latent timestep + latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0]) + + # make copy of initial_noise + block_state.initial_noise = block_state.latents + + # scale noise + block_state.latents = components.scheduler.scale_noise( + block_state.image_latents, latent_timestep, block_state.latents + ) + + self.set_block_state(state, block_state) + + return components, state + + +# auto_docstring +class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks): + """ + Step that creates mask latents from preprocessed mask_image by interpolating to latent space. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + processed_mask_image (`Tensor`): + The processed mask to use for the inpainting process. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. + + Outputs: + mask (`Tensor`): + The mask to use for the inpainting process. + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Step that creates mask latents from preprocessed mask_image by interpolating to latent space." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="processed_mask_image", + required=True, + type_hint=torch.Tensor, + description="The processed mask to use for the inpainting process.", + ), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("dtype"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process." + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + + height_latents = 2 * (int(block_state.height) // (components.vae_scale_factor * 2)) + width_latents = 2 * (int(block_state.width) // (components.vae_scale_factor * 2)) + + block_state.mask = torch.nn.functional.interpolate( + block_state.processed_mask_image, + size=(height_latents, width_latents), + ) + + block_state.mask = block_state.mask.unsqueeze(2) + block_state.mask = block_state.mask.repeat(1, components.num_channels_latents, 1, 1, 1) + block_state.mask = block_state.mask.to(device=device, dtype=block_state.dtype) + + block_state.mask = components.pachifier.pack_latents(block_state.mask) + + self.set_block_state(state, block_state) + + return components, state + + +# ==================== +# 2. SET TIMESTEPS +# ==================== + + +# auto_docstring +class QwenImageSetTimestepsStep(ModularPipelineBlocks): + """ + Step that sets the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + latents (`Tensor`): + The initial random noised latents for the denoising process. Can be generated in prepare latents step. + + Outputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("num_inference_steps"), + InputParam.template("sigmas"), + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The initial random noised latents for the denoising process. Can be generated in prepare latents step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process" + ), + ] + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + sigmas = ( + np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps) + if block_state.sigmas is None + else block_state.sigmas + ) + + mu = calculate_shift( + image_seq_len=block_state.latents.shape[1], + base_seq_len=components.scheduler.config.get("base_image_seq_len", 256), + max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096), + base_shift=components.scheduler.config.get("base_shift", 0.5), + max_shift=components.scheduler.config.get("max_shift", 1.15), + ) + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + scheduler=components.scheduler, + num_inference_steps=block_state.num_inference_steps, + device=device, + sigmas=sigmas, + mu=mu, + ) + + components.scheduler.set_begin_index(0) + + self.set_block_state(state, block_state) + + return components, state + + +# auto_docstring +class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks): + """ + Set timesteps step for QwenImage Layered with custom mu calculation based on image_latents. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. + """ + + model_name = "qwenimage-layered" + + @property + def description(self) -> str: + return "Set timesteps step for QwenImage Layered with custom mu calculation based on image_latents." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("num_inference_steps"), + InputParam.template("sigmas"), + InputParam.template("image_latents"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process." + ), + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + + # Layered-specific mu calculation + base_seqlen = 256 * 256 / 16 / 16 # = 256 + mu = (block_state.image_latents.shape[1] / base_seqlen) ** 0.5 + + # Default sigmas if not provided + sigmas = ( + np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps) + if block_state.sigmas is None + else block_state.sigmas + ) + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, + block_state.num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + + components.scheduler.set_begin_index(0) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks): + """ + Step that sets the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after prepare + latents step. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + latents (`Tensor`): + The latents to use for the denoising process. Can be generated in prepare latents step. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + + Outputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. + num_inference_steps (`int`): + The number of denoising steps to perform at inference time. Updated based on strength. + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after prepare latents step." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("num_inference_steps"), + InputParam.template("sigmas"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to use for the denoising process. Can be generated in prepare latents step.", + ), + InputParam.template("strength", default=0.9), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="timesteps", + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process.", + ), + OutputParam( + name="num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time. Updated based on strength.", + ), + ] + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + sigmas = ( + np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps) + if block_state.sigmas is None + else block_state.sigmas + ) + + mu = calculate_shift( + image_seq_len=block_state.latents.shape[1], + base_seq_len=components.scheduler.config.get("base_image_seq_len", 256), + max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096), + base_shift=components.scheduler.config.get("base_shift", 0.5), + max_shift=components.scheduler.config.get("max_shift", 1.15), + ) + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + scheduler=components.scheduler, + num_inference_steps=block_state.num_inference_steps, + device=device, + sigmas=sigmas, + mu=mu, + ) + + block_state.timesteps, block_state.num_inference_steps = get_timesteps( + scheduler=components.scheduler, + num_inference_steps=block_state.num_inference_steps, + strength=block_state.strength, + ) + + self.set_block_state(state, block_state) + + return components, state + + +# ==================== +# 3. OTHER INPUTS FOR DENOISER +# ==================== + +## RoPE inputs for denoiser + + +# auto_docstring +class QwenImageRoPEInputsStep(ModularPipelineBlocks): + """ + Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step + + Inputs: + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + + Outputs: + img_shapes (`list`): + The shapes of the images latents, used for RoPE calculation + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return ( + "Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("batch_size"), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds_mask"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="img_shapes", + kwargs_type="denoiser_input_fields", + type_hint=list[list[tuple[int, int, int]]], + description="The shapes of the images latents, used for RoPE calculation", + ), + ] + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.img_shapes = [ + [ + ( + 1, + block_state.height // components.vae_scale_factor // 2, + block_state.width // components.vae_scale_factor // 2, + ) + ] + ] * block_state.batch_size + + self.set_block_state(state, block_state) + + return components, state + + +# auto_docstring +class QwenImageEditRoPEInputsStep(ModularPipelineBlocks): + """ + Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. Should be placed after + prepare_latents step + + Inputs: + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + image_height (`int`): + The height of the reference image. Can be generated in input step. + image_width (`int`): + The width of the reference image. Can be generated in input step. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + + Outputs: + img_shapes (`list`): + The shapes of the images latents, used for RoPE calculation + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. Should be placed after prepare_latents step" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("batch_size"), + InputParam( + name="image_height", + required=True, + type_hint=int, + description="The height of the reference image. Can be generated in input step.", + ), + InputParam( + name="image_width", + required=True, + type_hint=int, + description="The width of the reference image. Can be generated in input step.", + ), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds_mask"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="img_shapes", + kwargs_type="denoiser_input_fields", + type_hint=list[list[tuple[int, int, int]]], + description="The shapes of the images latents, used for RoPE calculation", + ), + ] + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # for edit, image size can be different from the target size (height/width) + block_state.img_shapes = [ + [ + ( + 1, + block_state.height // components.vae_scale_factor // 2, + block_state.width // components.vae_scale_factor // 2, + ), + ( + 1, + block_state.image_height // components.vae_scale_factor // 2, + block_state.image_width // components.vae_scale_factor // 2, + ), + ] + ] * block_state.batch_size + + self.set_block_state(state, block_state) + + return components, state + + +# auto_docstring +class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks): + """ + Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit Plus. + Unlike Edit, Edit Plus handles lists of image_height/image_width for multiple reference images. Should be placed + after prepare_latents step. + + Inputs: + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + image_height (`list`): + The heights of the reference images. Can be generated in input step. + image_width (`list`): + The widths of the reference images. Can be generated in input step. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + + Outputs: + img_shapes (`list`): + The shapes of the image latents, used for RoPE calculation + txt_seq_lens (`list`): + The sequence lengths of the prompt embeds, used for RoPE calculation + negative_txt_seq_lens (`list`): + The sequence lengths of the negative prompt embeds, used for RoPE calculation + """ + + model_name = "qwenimage-edit-plus" + + @property + def description(self) -> str: + return ( + "Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit Plus.\n" + "Unlike Edit, Edit Plus handles lists of image_height/image_width for multiple reference images.\n" + "Should be placed after prepare_latents step." + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("batch_size"), + InputParam( + name="image_height", + required=True, + type_hint=list[int], + description="The heights of the reference images. Can be generated in input step.", + ), + InputParam( + name="image_width", + required=True, + type_hint=list[int], + description="The widths of the reference images. Can be generated in input step.", + ), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds_mask"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="img_shapes", + kwargs_type="denoiser_input_fields", + type_hint=list[list[tuple[int, int, int]]], + description="The shapes of the image latents, used for RoPE calculation", + ), + OutputParam( + name="txt_seq_lens", + kwargs_type="denoiser_input_fields", + type_hint=list[int], + description="The sequence lengths of the prompt embeds, used for RoPE calculation", + ), + OutputParam( + name="negative_txt_seq_lens", + kwargs_type="denoiser_input_fields", + type_hint=list[int], + description="The sequence lengths of the negative prompt embeds, used for RoPE calculation", + ), + ] + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + vae_scale_factor = components.vae_scale_factor + + # Edit Plus: image_height and image_width are lists + block_state.img_shapes = [ + [ + (1, block_state.height // vae_scale_factor // 2, block_state.width // vae_scale_factor // 2), + *[ + (1, img_height // vae_scale_factor // 2, img_width // vae_scale_factor // 2) + for img_height, img_width in zip(block_state.image_height, block_state.image_width) + ], + ] + ] * block_state.batch_size + + block_state.txt_seq_lens = ( + block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None + ) + block_state.negative_txt_seq_lens = ( + block_state.negative_prompt_embeds_mask.sum(dim=1).tolist() + if block_state.negative_prompt_embeds_mask is not None + else None + ) + + self.set_block_state(state, block_state) + + return components, state + + +# auto_docstring +class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks): + """ + Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step + + Inputs: + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + + Outputs: + img_shapes (`list`): + The shapes of the image latents, used for RoPE calculation + txt_seq_lens (`list`): + The sequence lengths of the prompt embeds, used for RoPE calculation + negative_txt_seq_lens (`list`): + The sequence lengths of the negative prompt embeds, used for RoPE calculation + additional_t_cond (`Tensor`): + The additional t cond, used for RoPE calculation + """ + + model_name = "qwenimage-layered" + + @property + def description(self) -> str: + return ( + "Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("batch_size"), + InputParam.template("layers"), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds_mask"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="img_shapes", + type_hint=list[list[tuple[int, int, int]]], + kwargs_type="denoiser_input_fields", + description="The shapes of the image latents, used for RoPE calculation", + ), + OutputParam( + name="txt_seq_lens", + type_hint=list[int], + kwargs_type="denoiser_input_fields", + description="The sequence lengths of the prompt embeds, used for RoPE calculation", + ), + OutputParam( + name="negative_txt_seq_lens", + type_hint=list[int], + kwargs_type="denoiser_input_fields", + description="The sequence lengths of the negative prompt embeds, used for RoPE calculation", + ), + OutputParam( + name="additional_t_cond", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="The additional t cond, used for RoPE calculation", + ), + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + + # All shapes are the same for Layered + shape = ( + 1, + block_state.height // components.vae_scale_factor // 2, + block_state.width // components.vae_scale_factor // 2, + ) + + # layers+1 output shapes + 1 condition shape (all same) + block_state.img_shapes = [[shape] * (block_state.layers + 2)] * block_state.batch_size + + # txt_seq_lens + block_state.txt_seq_lens = ( + block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None + ) + block_state.negative_txt_seq_lens = ( + block_state.negative_prompt_embeds_mask.sum(dim=1).tolist() + if block_state.negative_prompt_embeds_mask is not None + else None + ) + + block_state.additional_t_cond = torch.tensor([0] * block_state.batch_size).to(device=device, dtype=torch.long) + + self.set_block_state(state, block_state) + return components, state + + +## ControlNet inputs for denoiser + + +# auto_docstring +class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks): + """ + step that prepare inputs for controlnet. Insert before the Denoise Step, after set_timesteps step. + + Components: + controlnet (`QwenImageControlNetModel`) + + Inputs: + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + + Outputs: + controlnet_keep (`list`): + The controlnet keep values + """ + + model_name = "qwenimage" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("controlnet", QwenImageControlNetModel), + ] + + @property + def description(self) -> str: + return "step that prepare inputs for controlnet. Insert before the Denoise Step, after set_timesteps step." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("control_guidance_start"), + InputParam.template("control_guidance_end"), + InputParam.template("controlnet_conditioning_scale"), + InputParam( + name="control_image_latents", + required=True, + type_hint=torch.Tensor, + description="The control image latents to use for the denoising process. Can be generated in controlnet vae encoder step.", + ), + InputParam( + name="timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("controlnet_keep", type_hint=list[float], description="The controlnet keep values"), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + controlnet = unwrap_module(components.controlnet) + + # control_guidance_start/control_guidance_end (align format) + if not isinstance(block_state.control_guidance_start, list) and isinstance( + block_state.control_guidance_end, list + ): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [ + block_state.control_guidance_start + ] + elif not isinstance(block_state.control_guidance_end, list) and isinstance( + block_state.control_guidance_start, list + ): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [ + block_state.control_guidance_end + ] + elif not isinstance(block_state.control_guidance_start, list) and not isinstance( + block_state.control_guidance_end, list + ): + mult = ( + len(block_state.control_image_latents) if isinstance(controlnet, QwenImageMultiControlNetModel) else 1 + ) + block_state.control_guidance_start, block_state.control_guidance_end = ( + mult * [block_state.control_guidance_start], + mult * [block_state.control_guidance_end], + ) + + # controlnet_conditioning_scale (align format) + if isinstance(controlnet, QwenImageMultiControlNetModel) and isinstance( + block_state.controlnet_conditioning_scale, float + ): + block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * mult + + # controlnet_keep + block_state.controlnet_keep = [] + for i in range(len(block_state.timesteps)): + keeps = [ + 1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e) + for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end) + ] + block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, QwenImageControlNetModel) else keeps) + + self.set_block_state(state, block_state) + + return components, state diff --git a/diffusers/modular_pipelines/qwenimage/decoders.py b/diffusers/modular_pipelines/qwenimage/decoders.py new file mode 100644 index 0000000000000000000000000000000000000000..e4ccb6b8e0470dbd9140e4f4927050801c5d629d --- /dev/null +++ b/diffusers/modular_pipelines/qwenimage/decoders.py @@ -0,0 +1,511 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any + +import torch + +from ...configuration_utils import FrozenDict +from ...image_processor import InpaintProcessor, VaeImageProcessor +from ...models import AutoencoderKLQwenImage +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import QwenImageLayeredPachifier, QwenImageModularPipeline, QwenImagePachifier + + +logger = logging.get_logger(__name__) + + +# after denoising loop (unpack latents) + + +# auto_docstring +class QwenImageAfterDenoiseStep(ModularPipelineBlocks): + """ + Step that unpack the latents from 3D tensor (batch_size, sequence_length, channels) into 5D tensor (batch_size, + channels, 1, height, width) + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + latents (`Tensor`): + The latents to decode, can be generated in the denoise step. + + Outputs: + latents (`Tensor`): + The denoisedlatents unpacked to B, C, 1, H, W + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Step that unpack the latents from 3D tensor (batch_size, sequence_length, channels) into 5D tensor (batch_size, channels, 1, height, width)" + + @property + def expected_components(self) -> list[ComponentSpec]: + components = [ + ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"), + ] + + return components + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The latents to decode, can be generated in the denoise step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="latents", type_hint=torch.Tensor, description="The denoisedlatents unpacked to B, C, 1, H, W" + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + vae_scale_factor = components.vae_scale_factor + block_state.latents = components.pachifier.unpack_latents( + block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor + ) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks): + """ + Unpack latents from (B, seq, C*4) to (B, C, layers+1, H, W) after denoising. + + Components: + pachifier (`QwenImageLayeredPachifier`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + + Outputs: + latents (`Tensor`): + Denoised latents. (unpacked to B, C, layers+1, H, W) + """ + + model_name = "qwenimage-layered" + + @property + def description(self) -> str: + return "Unpack latents from (B, seq, C*4) to (B, C, layers+1, H, W) after denoising." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("pachifier", QwenImageLayeredPachifier, default_creation_method="from_config"), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents to decode, can be generated in the denoise step.", + ), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("layers"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam.template("latents", note="unpacked to B, C, layers+1, H, W"), + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Unpack: (B, seq, C*4) -> (B, C, layers+1, H, W) + block_state.latents = components.pachifier.unpack_latents( + block_state.latents, + block_state.height, + block_state.width, + block_state.layers, + components.vae_scale_factor, + ) + + self.set_block_state(state, block_state) + return components, state + + +# decode step + + +# auto_docstring +class QwenImageDecoderStep(ModularPipelineBlocks): + """ + Step that decodes the latents to images + + Components: + vae (`AutoencoderKLQwenImage`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + + Outputs: + images (`list`): + Generated images. (tensor output of the vae decoder.) + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Step that decodes the latents to images" + + @property + def expected_components(self) -> list[ComponentSpec]: + components = [ + ComponentSpec("vae", AutoencoderKLQwenImage), + ] + + return components + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam.template("images", note="tensor output of the vae decoder.")] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # YiYi Notes: remove support for output_type = "latents', we can just skip decode/encode step in modular + if block_state.latents.ndim == 4: + block_state.latents = block_state.latents.unsqueeze(dim=1) + elif block_state.latents.ndim != 5: + raise ValueError( + f"expect latents to be a 4D or 5D tensor but got: {block_state.latents.shape}. Please make sure the latents are unpacked before decode step." + ) + block_state.latents = block_state.latents.to(components.vae.dtype) + + latents_mean = ( + torch.tensor(components.vae.config.latents_mean) + .view(1, components.vae.config.z_dim, 1, 1, 1) + .to(block_state.latents.device, block_state.latents.dtype) + ) + latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view( + 1, components.vae.config.z_dim, 1, 1, 1 + ).to(block_state.latents.device, block_state.latents.dtype) + block_state.latents = block_state.latents / latents_std + latents_mean + block_state.images = components.vae.decode(block_state.latents, return_dict=False)[0][:, :, 0] + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageLayeredDecoderStep(ModularPipelineBlocks): + """ + Decode unpacked latents (B, C, layers+1, H, W) into layer images. + + Components: + vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "qwenimage-layered" + + @property + def description(self) -> str: + return "Decode unpacked latents (B, C, layers+1, H, W) into layer images." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLQwenImage), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise step.", + ), + InputParam.template("output_type"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam.template("images")] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + latents = block_state.latents + + # 1. VAE normalization + latents = latents.to(components.vae.dtype) + latents_mean = ( + torch.tensor(components.vae.config.latents_mean) + .view(1, components.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view( + 1, components.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + + # 2. Reshape for batch decoding: (B, C, layers+1, H, W) -> (B*layers, C, 1, H, W) + b, c, f, h, w = latents.shape + # 3. Remove first frame (composite), keep layers frames + latents = latents[:, :, 1:] + latents = latents.permute(0, 2, 1, 3, 4).reshape(-1, c, 1, h, w) + + # 4. Decode: (B*layers, C, 1, H, W) -> (B*layers, C, H, W) + image = components.vae.decode(latents, return_dict=False)[0] + image = image.squeeze(2) + + # 5. Postprocess - returns flat list of B*layers images + image = components.image_processor.postprocess(image, output_type=block_state.output_type) + + # 6. Chunk into list per batch item + images = [] + for bidx in range(b): + images.append(image[bidx * f : (bidx + 1) * f]) + + block_state.images = images + + self.set_block_state(state, block_state) + return components, state + + +# postprocess the decoded images + + +# auto_docstring +class QwenImageProcessImagesOutputStep(ModularPipelineBlocks): + """ + postprocess the generated image + + Components: + image_processor (`VaeImageProcessor`) + + Inputs: + images (`Tensor`): + the generated image tensor from decoders step + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "postprocess the generated image" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="images", + required=True, + type_hint=torch.Tensor, + description="the generated image tensor from decoders step", + ), + InputParam.template("output_type"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam.template("images")] + + @staticmethod + def check_inputs(output_type): + if output_type not in ["pil", "np", "pt"]: + raise ValueError(f"Invalid output_type: {output_type}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + self.check_inputs(block_state.output_type) + + block_state.images = components.image_processor.postprocess( + image=block_state.images, + output_type=block_state.output_type, + ) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks): + """ + postprocess the generated image, optional apply the mask overally to the original image.. + + Components: + image_mask_processor (`InpaintProcessor`) + + Inputs: + images (`Tensor`): + the generated image tensor from decoders step + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + mask_overlay_kwargs (`dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. generated in + InpaintProcessImagesInputStep. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "postprocess the generated image, optional apply the mask overally to the original image.." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_mask_processor", + InpaintProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="images", + required=True, + type_hint=torch.Tensor, + description="the generated image tensor from decoders step", + ), + InputParam.template("output_type"), + InputParam( + name="mask_overlay_kwargs", + type_hint=dict[str, Any], + description="The kwargs for the postprocess step to apply the mask overlay. generated in InpaintProcessImagesInputStep.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam.template("images")] + + @staticmethod + def check_inputs(output_type, mask_overlay_kwargs): + if output_type not in ["pil", "np", "pt"]: + raise ValueError(f"Invalid output_type: {output_type}") + + if mask_overlay_kwargs and output_type != "pil": + raise ValueError("only support output_type 'pil' for mask overlay") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + self.check_inputs(block_state.output_type, block_state.mask_overlay_kwargs) + + if block_state.mask_overlay_kwargs is None: + mask_overlay_kwargs = {} + else: + mask_overlay_kwargs = block_state.mask_overlay_kwargs + + block_state.images = components.image_mask_processor.postprocess( + image=block_state.images, + **mask_overlay_kwargs, + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/diffusers/modular_pipelines/qwenimage/denoise.py b/diffusers/modular_pipelines/qwenimage/denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..de8ea05c5047cbe346e0ce61e59e2f6c5827d3f2 --- /dev/null +++ b/diffusers/modular_pipelines/qwenimage/denoise.py @@ -0,0 +1,943 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import QwenImageControlNetModel, QwenImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ..modular_pipeline import BlockState, LoopSequentialPipelineBlocks, ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import QwenImageModularPipeline + + +logger = logging.get_logger(__name__) + +# ==================== +# 1. LOOP STEPS (run at each denoising step) +# ==================== + + +# loop step:before denoiser +class QwenImageLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepares the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `QwenImageDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + # one timestep + block_state.timestep = t.expand(block_state.latents.shape[0]).to(block_state.latents.dtype) + block_state.latent_model_input = block_state.latents + return components, block_state + + +class QwenImageEditLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "qwenimage-edit" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepares the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `QwenImageDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam.template("image_latents"), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + # one timestep + + block_state.latent_model_input = torch.cat([block_state.latents, block_state.image_latents], dim=1) + block_state.timestep = t.expand(block_state.latents.shape[0]).to(block_state.latents.dtype) + return components, block_state + + +class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ComponentSpec("controlnet", QwenImageControlNetModel), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that runs the controlnet before the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `QwenImageDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "control_image_latents", + required=True, + type_hint=torch.Tensor, + description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", + ), + InputParam.template("controlnet_conditioning_scale", note="updated in prepare_controlnet_inputs step."), + InputParam( + name="controlnet_keep", + required=True, + type_hint=list[float], + description="The controlnet keep values. Can be generated in prepare_controlnet_inputs step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: int): + # cond_scale for the timestep (controlnet input) + if isinstance(block_state.controlnet_keep[i], list): + block_state.cond_scale = [ + c * s for c, s in zip(block_state.controlnet_conditioning_scale, block_state.controlnet_keep[i]) + ] + else: + controlnet_cond_scale = block_state.controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i] + + # run controlnet for the guidance batch + controlnet_block_samples = components.controlnet( + hidden_states=block_state.latent_model_input, + controlnet_cond=block_state.control_image_latents, + conditioning_scale=block_state.cond_scale, + timestep=block_state.timestep / 1000, + img_shapes=block_state.img_shapes, + encoder_hidden_states=block_state.prompt_embeds, + encoder_hidden_states_mask=block_state.prompt_embeds_mask, + return_dict=False, + ) + + block_state.additional_cond_kwargs["controlnet_block_samples"] = controlnet_block_samples + + return components, block_state + + +# loop step:denoiser +class QwenImageLoopDenoiser(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that denoise the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `QwenImageDenoiseLoopWrapper`)" + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", QwenImageTransformer2DModel), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("attention_kwargs"), + InputParam.template("denoiser_input_fields"), + InputParam( + "img_shapes", + required=True, + type_hint=list[tuple[int, int]], + description="The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + guider_inputs = { + "encoder_hidden_states": ( + getattr(block_state, "prompt_embeds", None), + getattr(block_state, "negative_prompt_embeds", None), + ), + "encoder_hidden_states_mask": ( + getattr(block_state, "prompt_embeds_mask", None), + getattr(block_state, "negative_prompt_embeds_mask", None), + ), + } + + transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys()) + additional_cond_kwargs = {} + for field_name, field_value in block_state.denoiser_input_fields.items(): + if field_name in transformer_args and field_name not in guider_inputs: + additional_cond_kwargs[field_name] = field_value + block_state.additional_cond_kwargs.update(additional_cond_kwargs) + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_state = components.guider.prepare_inputs(guider_inputs) + + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()} + + # YiYi TODO: add cache context + guider_state_batch.noise_pred = components.transformer( + hidden_states=block_state.latent_model_input, + timestep=block_state.timestep / 1000, + attention_kwargs=block_state.attention_kwargs, + return_dict=False, + **cond_kwargs, + **block_state.additional_cond_kwargs, + )[0] + + components.guider.cleanup_models(components.transformer) + + guider_output = components.guider(guider_state) + + # apply guidance rescale + pred_cond_norm = torch.norm(guider_output.pred_cond, dim=-1, keepdim=True) + pred_norm = torch.norm(guider_output.pred, dim=-1, keepdim=True) + block_state.noise_pred = guider_output.pred * (pred_cond_norm / pred_norm) + + return components, block_state + + +class QwenImageEditLoopDenoiser(ModularPipelineBlocks): + model_name = "qwenimage-edit" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that denoise the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `QwenImageDenoiseLoopWrapper`)" + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", QwenImageTransformer2DModel), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("attention_kwargs"), + InputParam.template("denoiser_input_fields"), + InputParam( + "img_shapes", + required=True, + type_hint=list[tuple[int, int]], + description="The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + guider_inputs = { + "encoder_hidden_states": ( + getattr(block_state, "prompt_embeds", None), + getattr(block_state, "negative_prompt_embeds", None), + ), + "encoder_hidden_states_mask": ( + getattr(block_state, "prompt_embeds_mask", None), + getattr(block_state, "negative_prompt_embeds_mask", None), + ), + } + + transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys()) + additional_cond_kwargs = {} + for field_name, field_value in block_state.denoiser_input_fields.items(): + if field_name in transformer_args and field_name not in guider_inputs: + additional_cond_kwargs[field_name] = field_value + block_state.additional_cond_kwargs.update(additional_cond_kwargs) + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_state = components.guider.prepare_inputs(guider_inputs) + + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()} + + # YiYi TODO: add cache context + guider_state_batch.noise_pred = components.transformer( + hidden_states=block_state.latent_model_input, + timestep=block_state.timestep / 1000, + attention_kwargs=block_state.attention_kwargs, + return_dict=False, + **cond_kwargs, + **block_state.additional_cond_kwargs, + )[0] + + components.guider.cleanup_models(components.transformer) + + guider_output = components.guider(guider_state) + + pred = guider_output.pred[:, : block_state.latents.size(1)] + pred_cond = guider_output.pred_cond[:, : block_state.latents.size(1)] + + # apply guidance rescale + pred_cond_norm = torch.norm(pred_cond, dim=-1, keepdim=True) + pred_norm = torch.norm(pred, dim=-1, keepdim=True) + block_state.noise_pred = pred * (pred_cond_norm / pred_norm) + + return components, block_state + + +# loop step:after denoiser +class QwenImageLoopAfterDenoiser(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that updates the latents. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `QwenImageDenoiseLoopWrapper`)" + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam.template("latents"), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred, + t, + block_state.latents, + return_dict=False, + )[0] + + if block_state.latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class QwenImageLoopAfterDenoiserInpaint(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that updates the latents using mask and image_latents for inpainting. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `QwenImageDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "mask", + required=True, + type_hint=torch.Tensor, + description="The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.", + ), + InputParam.template("image_latents"), + InputParam( + "initial_noise", + required=True, + type_hint=torch.Tensor, + description="The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam.template("latents"), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + block_state.init_latents_proper = block_state.image_latents + if i < len(block_state.timesteps) - 1: + block_state.noise_timestep = block_state.timesteps[i + 1] + block_state.init_latents_proper = components.scheduler.scale_noise( + block_state.init_latents_proper, torch.tensor([block_state.noise_timestep]), block_state.initial_noise + ) + + block_state.latents = ( + 1 - block_state.mask + ) * block_state.init_latents_proper + block_state.mask * block_state.latents + + return components, block_state + + +# ==================== +# 2. DENOISE LOOP WRAPPER: define the denoising loop logic +# ==================== +class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoise the latents over `timesteps`. " + "The specific steps with each iteration can be customized with `sub_blocks` attributes" + ) + + @property + def loop_expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def loop_inputs(self) -> list[InputParam]: + return [ + InputParam( + name="timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam.template("num_inference_steps", required=True), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + ) + + block_state.additional_cond_kwargs = {} + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + self.set_block_state(state, block_state) + + return components, state + + +# ==================== +# 3. DENOISE STEPS: compose the denoising loop with loop wrapper + loop steps +# ==================== + + +# Qwen Image (text2image, image2image) + + +# auto_docstring +class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageLoopBeforeDenoiser` + - `QwenImageLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + This block supports text2image and image2image tasks for QwenImage. + + Components: + guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`list`): + The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + + block_classes = [ + QwenImageLoopBeforeDenoiser, + QwenImageLoopDenoiser, + QwenImageLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents.\n" + "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method\n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `QwenImageLoopBeforeDenoiser`\n" + " - `QwenImageLoopDenoiser`\n" + " - `QwenImageLoopAfterDenoiser`\n" + "This block supports text2image and image2image tasks for QwenImage." + ) + + +# Qwen Image (inpainting) +# auto_docstring +class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageLoopBeforeDenoiser` + - `QwenImageLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + - `QwenImageLoopAfterDenoiserInpaint` + This block supports inpainting tasks for QwenImage. + + Components: + guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`list`): + The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step. + mask (`Tensor`): + The mask to use for the inpainting process. Can be generated in inpaint prepare latents step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + initial_noise (`Tensor`): + The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageLoopBeforeDenoiser, + QwenImageLoopDenoiser, + QwenImageLoopAfterDenoiser, + QwenImageLoopAfterDenoiserInpaint, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser", "after_denoiser_inpaint"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `QwenImageLoopBeforeDenoiser`\n" + " - `QwenImageLoopDenoiser`\n" + " - `QwenImageLoopAfterDenoiser`\n" + " - `QwenImageLoopAfterDenoiserInpaint`\n" + "This block supports inpainting tasks for QwenImage." + ) + + +# Qwen Image (text2image, image2image) with controlnet +# auto_docstring +class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageLoopBeforeDenoiser` + - `QwenImageLoopBeforeDenoiserControlNet` + - `QwenImageLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + This block supports text2img/img2img tasks with controlnet for QwenImage. + + Components: + guider (`ClassifierFreeGuidance`) controlnet (`QwenImageControlNetModel`) transformer + (`QwenImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + control_image_latents (`Tensor`): + The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. (updated in prepare_controlnet_inputs step.) + controlnet_keep (`list`): + The controlnet keep values. Can be generated in prepare_controlnet_inputs step. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`list`): + The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageLoopBeforeDenoiser, + QwenImageLoopBeforeDenoiserControlNet, + QwenImageLoopDenoiser, + QwenImageLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "before_denoiser_controlnet", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `QwenImageLoopBeforeDenoiser`\n" + " - `QwenImageLoopBeforeDenoiserControlNet`\n" + " - `QwenImageLoopDenoiser`\n" + " - `QwenImageLoopAfterDenoiser`\n" + "This block supports text2img/img2img tasks with controlnet for QwenImage." + ) + + +# Qwen Image (inpainting) with controlnet +# auto_docstring +class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageLoopBeforeDenoiser` + - `QwenImageLoopBeforeDenoiserControlNet` + - `QwenImageLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + - `QwenImageLoopAfterDenoiserInpaint` + This block supports inpainting tasks with controlnet for QwenImage. + + Components: + guider (`ClassifierFreeGuidance`) controlnet (`QwenImageControlNetModel`) transformer + (`QwenImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + control_image_latents (`Tensor`): + The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. (updated in prepare_controlnet_inputs step.) + controlnet_keep (`list`): + The controlnet keep values. Can be generated in prepare_controlnet_inputs step. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`list`): + The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step. + mask (`Tensor`): + The mask to use for the inpainting process. Can be generated in inpaint prepare latents step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + initial_noise (`Tensor`): + The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageLoopBeforeDenoiser, + QwenImageLoopBeforeDenoiserControlNet, + QwenImageLoopDenoiser, + QwenImageLoopAfterDenoiser, + QwenImageLoopAfterDenoiserInpaint, + ] + block_names = [ + "before_denoiser", + "before_denoiser_controlnet", + "denoiser", + "after_denoiser", + "after_denoiser_inpaint", + ] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `QwenImageLoopBeforeDenoiser`\n" + " - `QwenImageLoopBeforeDenoiserControlNet`\n" + " - `QwenImageLoopDenoiser`\n" + " - `QwenImageLoopAfterDenoiser`\n" + " - `QwenImageLoopAfterDenoiserInpaint`\n" + "This block supports inpainting tasks with controlnet for QwenImage." + ) + + +# Qwen Image Edit (image2image) +# auto_docstring +class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageEditLoopBeforeDenoiser` + - `QwenImageEditLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + This block supports QwenImage Edit. + + Components: + guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`list`): + The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage-edit" + block_classes = [ + QwenImageEditLoopBeforeDenoiser, + QwenImageEditLoopDenoiser, + QwenImageLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `QwenImageEditLoopBeforeDenoiser`\n" + " - `QwenImageEditLoopDenoiser`\n" + " - `QwenImageLoopAfterDenoiser`\n" + "This block supports QwenImage Edit." + ) + + +# Qwen Image Edit (inpainting) +# auto_docstring +class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageEditLoopBeforeDenoiser` + - `QwenImageEditLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + - `QwenImageLoopAfterDenoiserInpaint` + This block supports inpainting tasks for QwenImage Edit. + + Components: + guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`list`): + The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step. + mask (`Tensor`): + The mask to use for the inpainting process. Can be generated in inpaint prepare latents step. + initial_noise (`Tensor`): + The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage-edit" + block_classes = [ + QwenImageEditLoopBeforeDenoiser, + QwenImageEditLoopDenoiser, + QwenImageLoopAfterDenoiser, + QwenImageLoopAfterDenoiserInpaint, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser", "after_denoiser_inpaint"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `QwenImageEditLoopBeforeDenoiser`\n" + " - `QwenImageEditLoopDenoiser`\n" + " - `QwenImageLoopAfterDenoiser`\n" + " - `QwenImageLoopAfterDenoiserInpaint`\n" + "This block supports inpainting tasks for QwenImage Edit." + ) + + +# Qwen Image Layered (image2image) +# auto_docstring +class QwenImageLayeredDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageEditLoopBeforeDenoiser` + - `QwenImageEditLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + This block supports QwenImage Layered. + + Components: + guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`list`): + The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage-layered" + block_classes = [ + QwenImageEditLoopBeforeDenoiser, + QwenImageEditLoopDenoiser, + QwenImageLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `QwenImageEditLoopBeforeDenoiser`\n" + " - `QwenImageEditLoopDenoiser`\n" + " - `QwenImageLoopAfterDenoiser`\n" + "This block supports QwenImage Layered." + ) diff --git a/diffusers/modular_pipelines/qwenimage/encoders.py b/diffusers/modular_pipelines/qwenimage/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..527267dc0d6eb10348d9d9a0a87ea2154a0c6261 --- /dev/null +++ b/diffusers/modular_pipelines/qwenimage/encoders.py @@ -0,0 +1,1780 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Text and VAE encoder blocks for QwenImage pipelines. +""" + +import PIL +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...image_processor import InpaintProcessor, VaeImageProcessor, is_valid_image, is_valid_image_imagelist +from ...models import AutoencoderKLQwenImage, QwenImageControlNetModel, QwenImageMultiControlNetModel +from ...pipelines.qwenimage.pipeline_qwenimage_edit import calculate_dimensions +from ...utils import logging +from ...utils.torch_utils import unwrap_module +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import QwenImageModularPipeline +from .prompt_templates import ( + QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE, + QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE, + QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX, + QWENIMAGE_EDIT_PROMPT_TEMPLATE, + QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX, + QWENIMAGE_LAYERED_CAPTION_PROMPT_CN, + QWENIMAGE_LAYERED_CAPTION_PROMPT_EN, + QWENIMAGE_PROMPT_TEMPLATE, + QWENIMAGE_PROMPT_TEMPLATE_START_IDX, +) + + +logger = logging.get_logger(__name__) + + +def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + return split_result + + +def get_qwen_prompt_embeds( + text_encoder, + tokenizer, + prompt: str | list[str] = None, + prompt_template_encode: str = QWENIMAGE_PROMPT_TEMPLATE, + prompt_template_encode_start_idx: int = QWENIMAGE_PROMPT_TEMPLATE_START_IDX, + tokenizer_max_length: int = 1024, + device: torch.device | None = None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = prompt_template_encode + drop_idx = prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + txt_tokens = tokenizer( + txt, max_length=tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + ).to(device) + encoder_hidden_states = text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + + split_hidden_states = _extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(device=device) + + return prompt_embeds, encoder_attention_mask + + +def get_qwen_prompt_embeds_edit( + text_encoder, + processor, + prompt: str | list[str] = None, + image: torch.Tensor | None = None, + prompt_template_encode: str = QWENIMAGE_EDIT_PROMPT_TEMPLATE, + prompt_template_encode_start_idx: int = QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX, + device: torch.device | None = None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = prompt_template_encode + drop_idx = prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + + model_inputs = processor( + text=txt, + images=image, + padding=True, + return_tensors="pt", + ).to(device) + + outputs = text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + + hidden_states = outputs.hidden_states[-1] + split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(device=device) + + return prompt_embeds, encoder_attention_mask + + +def get_qwen_prompt_embeds_edit_plus( + text_encoder, + processor, + prompt: str | list[str] = None, + image: torch.Tensor | list[PIL.Image.Image, PIL.Image.Image] | None = None, + prompt_template_encode: str = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE, + img_template_encode: str = QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE, + prompt_template_encode_start_idx: int = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX, + device: torch.device | None = None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + if isinstance(image, list): + base_img_prompt = "" + for i, img in enumerate(image): + base_img_prompt += img_template_encode.format(i + 1) + elif image is not None: + base_img_prompt = img_template_encode.format(1) + else: + base_img_prompt = "" + + template = prompt_template_encode + + drop_idx = prompt_template_encode_start_idx + txt = [template.format(base_img_prompt + e) for e in prompt] + + model_inputs = processor( + text=txt, + images=image, + padding=True, + return_tensors="pt", + ).to(device) + outputs = text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + + hidden_states = outputs.hidden_states[-1] + split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(device=device) + return prompt_embeds, encoder_attention_mask + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Modified from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._encode_vae_image +def encode_vae_image( + image: torch.Tensor, + vae: AutoencoderKLQwenImage, + generator: torch.Generator, + device: torch.device, + dtype: torch.dtype, + latent_channels: int = 16, + sample_mode: str = "argmax", +): + if not isinstance(image, torch.Tensor): + raise ValueError(f"Expected image to be a tensor, got {type(image)}.") + + # preprocessed image should be a 4D tensor: batch_size, num_channels, height, width + if image.dim() == 4: + image = image.unsqueeze(2) + elif image.dim() != 5: + raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.") + + image = image.to(device=device, dtype=dtype) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode=sample_mode) + latents_mean = ( + torch.tensor(vae.config.latents_mean) + .view(1, latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = ( + torch.tensor(vae.config.latents_std) + .view(1, latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + image_latents = (image_latents - latents_mean) / latents_std + + return image_latents + + +# ==================== +# 1. RESIZE +# ==================== +# In QwenImage pipelines, resize is a separate step because the resized image is used in VL encoding and vae encoder blocks: +# +# image (PIL.Image.Image) +# │ +# ▼ +# resized_image ([PIL.Image.Image]) +# │ +# ├──► text_encoder ──► prompt_embeds, prompt_embeds_mask +# │ (VL encoding needs the resized image for vision-language fusion) +# │ +# └──► image_processor ──► processed_image (torch.Tensor, pixel space) +# │ +# ▼ +# vae_encoder ──► image_latents (torch.Tensor, latent space) +# +# In most of our other pipelines, resizing is done as part of the image preprocessing step. +# ==================== + + +# auto_docstring +class QwenImageEditResizeStep(ModularPipelineBlocks): + """ + Image Resize step that resize the image to target area while maintaining the aspect ratio. + + Components: + image_resize_processor (`VaeImageProcessor`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + + Outputs: + resized_image (`list`): + The resized images + """ + + model_name = "qwenimage-edit" + + @property + def description(self) -> str: + return "Image Resize step that resize the image to target area while maintaining the aspect ratio." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_resize_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [InputParam.template("image")] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="resized_image", + type_hint=list[PIL.Image.Image], + description="The resized images", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + images = block_state.image + + if not is_valid_image_imagelist(images): + raise ValueError(f"Images must be image or list of images but are {type(images)}") + + if is_valid_image(images): + images = [images] + + image_width, image_height = images[0].size + calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_width / image_height) + + resized_images = [ + components.image_resize_processor.resize(image, height=calculated_height, width=calculated_width) + for image in images + ] + + block_state.resized_image = resized_images + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageLayeredResizeStep(ModularPipelineBlocks): + """ + Image Resize step that resize the image to a target area (defined by the resolution parameter from user) while + maintaining the aspect ratio. + + Components: + image_resize_processor (`VaeImageProcessor`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + resolution (`int`, *optional*, defaults to 640): + The target area to resize the image to, can be 1024 or 640 + + Outputs: + resized_image (`list`): + The resized images + """ + + model_name = "qwenimage-layered" + + @property + def description(self) -> str: + return "Image Resize step that resize the image to a target area (defined by the resolution parameter from user) while maintaining the aspect ratio." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_resize_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("image"), + InputParam( + name="resolution", + default=640, + type_hint=int, + description="The target area to resize the image to, can be 1024 or 640", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="resized_image", + type_hint=list[PIL.Image.Image], + description="The resized images", + ) + ] + + @staticmethod + def check_inputs(resolution: int): + if resolution not in [1024, 640]: + raise ValueError(f"Resolution must be 1024 or 640 but is {resolution}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + self.check_inputs(resolution=block_state.resolution) + + images = block_state.image + + if not is_valid_image_imagelist(images): + raise ValueError(f"Images must be image or list of images but are {type(images)}") + + if is_valid_image(images): + images = [images] + + image_width, image_height = images[0].size + target_area = block_state.resolution * block_state.resolution + calculated_width, calculated_height, _ = calculate_dimensions(target_area, image_width / image_height) + + resized_images = [ + components.image_resize_processor.resize(image, height=calculated_height, width=calculated_width) + for image in images + ] + + block_state.resized_image = resized_images + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageEditPlusResizeStep(ModularPipelineBlocks): + """ + Resize images for QwenImage Edit Plus pipeline. + Produces two outputs: resized_image (1024x1024) for VAE encoding, resized_cond_image (384x384) for VL text + encoding. Each image is resized independently based on its own aspect ratio. + + Components: + image_resize_processor (`VaeImageProcessor`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + + Outputs: + resized_image (`list`): + Images resized to 1024x1024 target area for VAE encoding + resized_cond_image (`list`): + Images resized to 384x384 target area for VL text encoding + """ + + model_name = "qwenimage-edit-plus" + + @property + def description(self) -> str: + return ( + "Resize images for QwenImage Edit Plus pipeline.\n" + "Produces two outputs: resized_image (1024x1024) for VAE encoding, " + "resized_cond_image (384x384) for VL text encoding.\n" + "Each image is resized independently based on its own aspect ratio." + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_resize_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + # image + return [InputParam.template("image")] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="resized_image", + type_hint=list[PIL.Image.Image], + description="Images resized to 1024x1024 target area for VAE encoding", + ), + OutputParam( + name="resized_cond_image", + type_hint=list[PIL.Image.Image], + description="Images resized to 384x384 target area for VL text encoding", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + images = block_state.image + + if not is_valid_image_imagelist(images): + raise ValueError(f"Images must be image or list of images but are {type(images)}") + + if is_valid_image(images): + images = [images] + + # Resize each image independently based on its own aspect ratio + resized_images = [] + resized_cond_images = [] + for image in images: + image_width, image_height = image.size + + # For VAE encoder (1024x1024 target area) + vae_width, vae_height, _ = calculate_dimensions(1024 * 1024, image_width / image_height) + resized_images.append(components.image_resize_processor.resize(image, height=vae_height, width=vae_width)) + + # For VL text encoder (384x384 target area) + vl_width, vl_height, _ = calculate_dimensions(384 * 384, image_width / image_height) + resized_cond_images.append( + components.image_resize_processor.resize(image, height=vl_height, width=vl_width) + ) + + block_state.resized_image = resized_images + block_state.resized_cond_image = resized_cond_images + self.set_block_state(state, block_state) + return components, state + + +# ==================== +# 2. GET IMAGE PROMPT +# ==================== + + +# auto_docstring +class QwenImageLayeredGetImagePromptStep(ModularPipelineBlocks): + """ + Auto-caption step that generates a text prompt from the input image if none is provided. + Uses the VL model (text_encoder) to generate a description of the image. If prompt is already provided, this step + passes through unchanged. + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor (`Qwen2VLProcessor`) + + Inputs: + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + resized_image (`Image`): + The image to generate caption from, should be resized use the resize step + use_en_prompt (`bool`, *optional*, defaults to False): + Whether to use English prompt template + + Outputs: + prompt (`str`): + The prompt or prompts to guide image generation. If not provided, updated using image caption + """ + + model_name = "qwenimage-layered" + + def __init__(self): + self.image_caption_prompt_en = QWENIMAGE_LAYERED_CAPTION_PROMPT_EN + self.image_caption_prompt_cn = QWENIMAGE_LAYERED_CAPTION_PROMPT_CN + super().__init__() + + @property + def description(self) -> str: + return ( + "Auto-caption step that generates a text prompt from the input image if none is provided.\n" + "Uses the VL model (text_encoder) to generate a description of the image.\n" + "If prompt is already provided, this step passes through unchanged." + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration), + ComponentSpec("processor", Qwen2VLProcessor), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template( + "prompt", required=False + ), # it is not required for qwenimage-layered, unlike other pipelines + InputParam( + name="resized_image", + required=True, + type_hint=PIL.Image.Image, + description="The image to generate caption from, should be resized use the resize step", + ), + InputParam( + name="use_en_prompt", + default=False, + type_hint=bool, + description="Whether to use English prompt template", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="prompt", + type_hint=str, + description="The prompt or prompts to guide image generation. If not provided, updated using image caption", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + + # If prompt is empty or None, generate caption from image + if block_state.prompt is None or block_state.prompt == "" or block_state.prompt == " ": + if block_state.use_en_prompt: + caption_prompt = self.image_caption_prompt_en + else: + caption_prompt = self.image_caption_prompt_cn + + model_inputs = components.processor( + text=caption_prompt, + images=block_state.resized_image, + padding=True, + return_tensors="pt", + ).to(device) + + generated_ids = components.text_encoder.generate(**model_inputs, max_new_tokens=512) + generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids) + ] + output_text = components.processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + )[0] + + block_state.prompt = output_text.strip() + + self.set_block_state(state, block_state) + return components, state + + +# ==================== +# 3. TEXT ENCODER +# ==================== + + +# auto_docstring +class QwenImageTextEncoderStep(ModularPipelineBlocks): + """ + Text Encoder step that generates text embeddings to guide the image generation. + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`): + The tokenizer to use guider (`ClassifierFreeGuidance`) + + Inputs: + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + + Outputs: + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + + model_name = "qwenimage" + + def __init__(self): + self.prompt_template_encode = QWENIMAGE_PROMPT_TEMPLATE + self.prompt_template_encode_start_idx = QWENIMAGE_PROMPT_TEMPLATE_START_IDX + self.tokenizer_max_length = 1024 + super().__init__() + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings to guide the image generation." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration, description="The text encoder to use"), + ComponentSpec("tokenizer", Qwen2Tokenizer, description="The tokenizer to use"), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("prompt"), + InputParam.template("negative_prompt"), + InputParam.template("max_sequence_length", default=1024), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam.template("prompt_embeds"), + OutputParam.template("prompt_embeds_mask"), + OutputParam.template("negative_prompt_embeds"), + OutputParam.template("negative_prompt_embeds_mask"), + ] + + @staticmethod + def check_inputs(prompt, negative_prompt, max_sequence_length): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if ( + negative_prompt is not None + and not isinstance(negative_prompt, str) + and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + device = components._execution_device + self.check_inputs(block_state.prompt, block_state.negative_prompt, block_state.max_sequence_length) + + block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds( + components.text_encoder, + components.tokenizer, + prompt=block_state.prompt, + prompt_template_encode=self.prompt_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, + tokenizer_max_length=self.tokenizer_max_length, + device=device, + ) + + block_state.prompt_embeds = block_state.prompt_embeds[:, : block_state.max_sequence_length] + block_state.prompt_embeds_mask = block_state.prompt_embeds_mask[:, : block_state.max_sequence_length] + + block_state.negative_prompt_embeds = None + block_state.negative_prompt_embeds_mask = None + if components.requires_unconditional_embeds: + negative_prompt = block_state.negative_prompt or "" + block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds( + components.text_encoder, + components.tokenizer, + prompt=negative_prompt, + prompt_template_encode=self.prompt_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, + tokenizer_max_length=self.tokenizer_max_length, + device=device, + ) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds[ + :, : block_state.max_sequence_length + ] + block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask[ + :, : block_state.max_sequence_length + ] + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageEditTextEncoderStep(ModularPipelineBlocks): + """ + Text Encoder step that processes both prompt and image together to generate text embeddings for guiding image + generation. + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor (`Qwen2VLProcessor`) guider + (`ClassifierFreeGuidance`) + + Inputs: + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + resized_image (`Image`): + The image prompt to encode, should be resized using resize step + + Outputs: + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + + model_name = "qwenimage" + + def __init__(self): + self.prompt_template_encode = QWENIMAGE_EDIT_PROMPT_TEMPLATE + self.prompt_template_encode_start_idx = QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX + super().__init__() + + @property + def description(self) -> str: + return "Text Encoder step that processes both prompt and image together to generate text embeddings for guiding image generation." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration), + ComponentSpec("processor", Qwen2VLProcessor), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("prompt"), + InputParam.template("negative_prompt"), + InputParam( + name="resized_image", + required=True, + type_hint=PIL.Image.Image, + description="The image prompt to encode, should be resized using resize step", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam.template("prompt_embeds"), + OutputParam.template("prompt_embeds_mask"), + OutputParam.template("negative_prompt_embeds"), + OutputParam.template("negative_prompt_embeds_mask"), + ] + + @staticmethod + def check_inputs(prompt, negative_prompt): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if ( + negative_prompt is not None + and not isinstance(negative_prompt, str) + and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + self.check_inputs(block_state.prompt, block_state.negative_prompt) + + device = components._execution_device + + block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds_edit( + components.text_encoder, + components.processor, + prompt=block_state.prompt, + image=block_state.resized_image, + prompt_template_encode=self.prompt_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, + device=device, + ) + + block_state.negative_prompt_embeds = None + block_state.negative_prompt_embeds_mask = None + if components.requires_unconditional_embeds: + negative_prompt = block_state.negative_prompt or " " + block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit( + components.text_encoder, + components.processor, + prompt=negative_prompt, + image=block_state.resized_image, + prompt_template_encode=self.prompt_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, + device=device, + ) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageEditPlusTextEncoderStep(ModularPipelineBlocks): + """ + Text Encoder step for QwenImage Edit Plus that processes prompt and multiple images together to generate text + embeddings for guiding image generation. + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor (`Qwen2VLProcessor`) guider + (`ClassifierFreeGuidance`) + + Inputs: + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + resized_cond_image (`Tensor`): + The image(s) to encode, can be a single image or list of images, should be resized to 384x384 using + resize step + + Outputs: + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + + model_name = "qwenimage-edit-plus" + + def __init__(self): + self.prompt_template_encode = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE + self.img_template_encode = QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE + self.prompt_template_encode_start_idx = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX + super().__init__() + + @property + def description(self) -> str: + return ( + "Text Encoder step for QwenImage Edit Plus that processes prompt and multiple images together " + "to generate text embeddings for guiding image generation." + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration), + ComponentSpec("processor", Qwen2VLProcessor), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("prompt"), + InputParam.template("negative_prompt"), + InputParam( + name="resized_cond_image", + required=True, + type_hint=torch.Tensor, + description="The image(s) to encode, can be a single image or list of images, should be resized to 384x384 using resize step", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam.template("prompt_embeds"), + OutputParam.template("prompt_embeds_mask"), + OutputParam.template("negative_prompt_embeds"), + OutputParam.template("negative_prompt_embeds_mask"), + ] + + @staticmethod + def check_inputs(prompt, negative_prompt): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if ( + negative_prompt is not None + and not isinstance(negative_prompt, str) + and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + self.check_inputs(block_state.prompt, block_state.negative_prompt) + + device = components._execution_device + + block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds_edit_plus( + components.text_encoder, + components.processor, + prompt=block_state.prompt, + image=block_state.resized_cond_image, + prompt_template_encode=self.prompt_template_encode, + img_template_encode=self.img_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, + device=device, + ) + + block_state.negative_prompt_embeds = None + block_state.negative_prompt_embeds_mask = None + if components.requires_unconditional_embeds: + negative_prompt = block_state.negative_prompt or " " + block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = ( + get_qwen_prompt_embeds_edit_plus( + components.text_encoder, + components.processor, + prompt=negative_prompt, + image=block_state.resized_cond_image, + prompt_template_encode=self.prompt_template_encode, + img_template_encode=self.img_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, + device=device, + ) + ) + + self.set_block_state(state, block_state) + return components, state + + +# ==================== +# 4. IMAGE PREPROCESS +# ==================== + + +# auto_docstring +class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks): + """ + Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images will be + resized to the given height and width. + + Components: + image_mask_processor (`InpaintProcessor`) + + Inputs: + mask_image (`Image`): + Mask image for inpainting. + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + + Outputs: + processed_image (`Tensor`): + The processed image + processed_mask_image (`Tensor`): + The processed mask image + mask_overlay_kwargs (`dict`): + The kwargs for the postprocess step to apply the mask overlay + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images will be resized to the given height and width." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_mask_processor", + InpaintProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("mask_image"), + InputParam.template("image"), + InputParam.template("height"), + InputParam.template("width"), + InputParam.template("padding_mask_crop"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="processed_image", + type_hint=torch.Tensor, + description="The processed image", + ), + OutputParam( + name="processed_mask_image", + type_hint=torch.Tensor, + description="The processed mask image", + ), + OutputParam( + name="mask_overlay_kwargs", + type_hint=dict, + description="The kwargs for the postprocess step to apply the mask overlay", + ), + ] + + @staticmethod + def check_inputs(height, width, vae_scale_factor): + if height is not None and height % (vae_scale_factor * 2) != 0: + raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}") + + if width is not None and width % (vae_scale_factor * 2) != 0: + raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + self.check_inputs( + height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor + ) + height = block_state.height or components.default_height + width = block_state.width or components.default_width + + block_state.processed_image, block_state.processed_mask_image, block_state.mask_overlay_kwargs = ( + components.image_mask_processor.preprocess( + image=block_state.image, + mask=block_state.mask_image, + height=height, + width=width, + padding_mask_crop=block_state.padding_mask_crop, + ) + ) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageEditInpaintProcessImagesInputStep(ModularPipelineBlocks): + """ + Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images should be + resized first. + + Components: + image_mask_processor (`InpaintProcessor`) + + Inputs: + mask_image (`Image`): + Mask image for inpainting. + resized_image (`Image`): + The resized image. should be generated using a resize step + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + + Outputs: + processed_image (`Tensor`): + The processed image + processed_mask_image (`Tensor`): + The processed mask image + mask_overlay_kwargs (`dict`): + The kwargs for the postprocess step to apply the mask overlay + """ + + model_name = "qwenimage-edit" + + @property + def description(self) -> str: + return "Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images should be resized first." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_mask_processor", + InpaintProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("mask_image"), + InputParam( + name="resized_image", + required=True, + type_hint=PIL.Image.Image, + description="The resized image. should be generated using a resize step", + ), + InputParam.template("padding_mask_crop"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam(name="processed_image", type_hint=torch.Tensor, description="The processed image"), + OutputParam( + name="processed_mask_image", + type_hint=torch.Tensor, + description="The processed mask image", + ), + OutputParam( + name="mask_overlay_kwargs", + type_hint=dict, + description="The kwargs for the postprocess step to apply the mask overlay", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + width, height = block_state.resized_image[0].size + + block_state.processed_image, block_state.processed_mask_image, block_state.mask_overlay_kwargs = ( + components.image_mask_processor.preprocess( + image=block_state.resized_image, + mask=block_state.mask_image, + height=height, + width=width, + padding_mask_crop=block_state.padding_mask_crop, + ) + ) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageProcessImagesInputStep(ModularPipelineBlocks): + """ + Image Preprocess step. will resize the image to the given height and width. + + Components: + image_processor (`VaeImageProcessor`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + + Outputs: + processed_image (`Tensor`): + The processed image + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Image Preprocess step. will resize the image to the given height and width." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("image"), + InputParam.template("height"), + InputParam.template("width"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="processed_image", + type_hint=torch.Tensor, + description="The processed image", + ) + ] + + @staticmethod + def check_inputs(height, width, vae_scale_factor): + if height is not None and height % (vae_scale_factor * 2) != 0: + raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}") + + if width is not None and width % (vae_scale_factor * 2) != 0: + raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + self.check_inputs( + height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor + ) + height = block_state.height or components.default_height + width = block_state.width or components.default_width + + block_state.processed_image = components.image_processor.preprocess( + image=block_state.image, + height=height, + width=width, + ) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageEditProcessImagesInputStep(ModularPipelineBlocks): + """ + Image Preprocess step. Images needs to be resized first. + + Components: + image_processor (`VaeImageProcessor`) + + Inputs: + resized_image (`list`): + The resized image. should be generated using a resize step + + Outputs: + processed_image (`Tensor`): + The processed image + """ + + model_name = "qwenimage-edit" + + @property + def description(self) -> str: + return "Image Preprocess step. Images needs to be resized first." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="resized_image", + required=True, + type_hint=list[PIL.Image.Image], + description="The resized image. should be generated using a resize step", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="processed_image", + type_hint=torch.Tensor, + description="The processed image", + ) + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + width, height = block_state.resized_image[0].size + + block_state.processed_image = components.image_processor.preprocess( + image=block_state.resized_image, + height=height, + width=width, + ) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageEditPlusProcessImagesInputStep(ModularPipelineBlocks): + """ + Image Preprocess step. Images can be resized first. If a list of images is provided, will return a list of + processed images. + + Components: + image_processor (`VaeImageProcessor`) + + Inputs: + resized_image (`list`): + The resized image. should be generated using a resize step + + Outputs: + processed_image (`Tensor`): + The processed image + """ + + model_name = "qwenimage-edit-plus" + + @property + def description(self) -> str: + return "Image Preprocess step. Images can be resized first. If a list of images is provided, will return a list of processed images." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="resized_image", + required=True, + type_hint=list[PIL.Image.Image], + description="The resized image. should be generated using a resize step", + ) + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="processed_image", + type_hint=torch.Tensor, + description="The processed image", + ) + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + image = block_state.resized_image + + is_image_list = isinstance(image, list) + if not is_image_list: + image = [image] + + processed_images = [] + for img in image: + img_width, img_height = img.size + processed_images.append( + components.image_processor.preprocess(image=img, height=img_height, width=img_width) + ) + + if is_image_list: + block_state.processed_image = processed_images + else: + block_state.processed_image = processed_images[0] + + self.set_block_state(state, block_state) + return components, state + + +# ==================== +# 5. VAE ENCODER +# ==================== + + +# auto_docstring +class QwenImageVaeEncoderStep(ModularPipelineBlocks): + """ + VAE Encoder step that converts processed_image into latent representations image_latents. + Handles both single images and lists of images with varied resolutions. + + Components: + vae (`AutoencoderKLQwenImage`) + + Inputs: + processed_image (`Tensor`): + The image tensor to encode + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + image_latents (`Tensor`): + The latent representation of the input image. + """ + + model_name = "qwenimage" + + def __init__(self, input: InputParam | None = None, output: OutputParam | None = None): + """Initialize a VAE encoder step for converting images to latent representations. + + Handles both single images and lists of images. When input is a list, outputs a list of latents. When input is + a single tensor, outputs a single latent tensor. + + Args: + input (InputParam, optional): Input parameter for the processed image. Defaults to "processed_image". + output (OutputParam, optional): Output parameter for the image latents. Defaults to "image_latents". + """ + if input is None: + input = InputParam( + name="processed_image", required=True, type_hint=torch.Tensor, description="The image tensor to encode" + ) + + if output is None: + output = OutputParam.template("image_latents") + + if not isinstance(input, InputParam): + raise ValueError(f"input must be InputParam but is {type(input)}") + if not isinstance(output, OutputParam): + raise ValueError(f"output must be OutputParam but is {type(output)}") + + self._input = input + self._output = output + self._image_input_name = input.name + self._image_latents_output_name = output.name + super().__init__() + + @property + def description(self) -> str: + return ( + f"VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n" + "Handles both single images and lists of images with varied resolutions." + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("vae", AutoencoderKLQwenImage)] + + @property + def inputs(self) -> list[InputParam]: + return [ + self._input, # default is "processed_image" + InputParam.template("generator"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [self._output] # default is "image_latents" + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + dtype = components.vae.dtype + + image = getattr(block_state, self._image_input_name) + is_image_list = isinstance(image, list) + if not is_image_list: + image = [image] + + # Handle both single image and list of images + image_latents = [] + for img in image: + image_latents.append( + encode_vae_image( + image=img, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=dtype, + latent_channels=components.num_channels_latents, + ) + ) + if not is_image_list: + image_latents = image_latents[0] + + setattr(block_state, self._image_latents_output_name, image_latents) + + self.set_block_state(state, block_state) + + return components, state + + +# auto_docstring +class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks): + """ + VAE Encoder step that converts `control_image` into latent representations control_image_latents. + + Components: + vae (`AutoencoderKLQwenImage`) controlnet (`QwenImageControlNetModel`) control_image_processor + (`VaeImageProcessor`) + + Inputs: + control_image (`Image`): + Control image for ControlNet conditioning. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + control_image_latents (`Tensor`): + The latents representing the control image + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "VAE Encoder step that converts `control_image` into latent representations control_image_latents.\n" + + @property + def expected_components(self) -> list[ComponentSpec]: + components = [ + ComponentSpec("vae", AutoencoderKLQwenImage), + ComponentSpec("controlnet", QwenImageControlNetModel), + ComponentSpec( + "control_image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + return components + + @property + def inputs(self) -> list[InputParam]: + inputs = [ + InputParam.template("control_image"), + InputParam.template("height"), + InputParam.template("width"), + InputParam.template("generator"), + ] + return inputs + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "control_image_latents", + type_hint=torch.Tensor, + description="The latents representing the control image", + ) + ] + + @staticmethod + def check_inputs(height, width, vae_scale_factor): + if height is not None and height % (vae_scale_factor * 2) != 0: + raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}") + + if width is not None and width % (vae_scale_factor * 2) != 0: + raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + self.check_inputs(block_state.height, block_state.width, components.vae_scale_factor) + + device = components._execution_device + dtype = components.vae.dtype + + height = block_state.height or components.default_height + width = block_state.width or components.default_width + + controlnet = unwrap_module(components.controlnet) + if isinstance(controlnet, QwenImageMultiControlNetModel) and not isinstance(block_state.control_image, list): + block_state.control_image = [block_state.control_image] + + if isinstance(controlnet, QwenImageMultiControlNetModel): + block_state.control_image_latents = [] + for control_image_ in block_state.control_image: + control_image_ = components.control_image_processor.preprocess( + image=control_image_, + height=height, + width=width, + ) + + control_image_latents_ = encode_vae_image( + image=control_image_, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=dtype, + latent_channels=components.num_channels_latents, + sample_mode="sample", + ) + block_state.control_image_latents.append(control_image_latents_) + + elif isinstance(controlnet, QwenImageControlNetModel): + control_image = components.control_image_processor.preprocess( + image=block_state.control_image, + height=height, + width=width, + ) + block_state.control_image_latents = encode_vae_image( + image=control_image, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=dtype, + latent_channels=components.num_channels_latents, + sample_mode="sample", + ) + + else: + raise ValueError( + f"Expected controlnet to be a QwenImageControlNetModel or QwenImageMultiControlNetModel, got {type(controlnet)}" + ) + + self.set_block_state(state, block_state) + + return components, state + + +# ==================== +# 6. PERMUTE LATENTS +# ==================== + + +# auto_docstring +class QwenImageLayeredPermuteLatentsStep(ModularPipelineBlocks): + """ + Permute image latents from (B, C, 1, H, W) to (B, 1, C, H, W) for Layered packing. + + Inputs: + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + image_latents (`Tensor`): + The latent representation of the input image. (permuted from [B, C, 1, H, W] to [B, 1, C, H, W]) + """ + + model_name = "qwenimage-layered" + + @property + def description(self) -> str: + return "Permute image latents from (B, C, 1, H, W) to (B, 1, C, H, W) for Layered packing." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("image_latents"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam.template("image_latents", note="permuted from [B, C, 1, H, W] to [B, 1, C, H, W]"), + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Permute: (B, C, 1, H, W) -> (B, 1, C, H, W) + latents = block_state.image_latents + block_state.image_latents = latents.permute(0, 2, 1, 3, 4) + + self.set_block_state(state, block_state) + return components, state diff --git a/diffusers/modular_pipelines/qwenimage/inputs.py b/diffusers/modular_pipelines/qwenimage/inputs.py new file mode 100644 index 0000000000000000000000000000000000000000..faec7db245df1ac1ab00245628c25abbf21e15f7 --- /dev/null +++ b/diffusers/modular_pipelines/qwenimage/inputs.py @@ -0,0 +1,1024 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +from ...models import QwenImageMultiControlNetModel +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import QwenImageLayeredPachifier, QwenImageModularPipeline, QwenImagePachifier + + +def repeat_tensor_to_batch_size( + input_name: str, + input_tensor: torch.Tensor, + batch_size: int, + num_images_per_prompt: int = 1, +) -> torch.Tensor: + """Repeat tensor elements to match the final batch size. + + This function expands a tensor's batch dimension to match the final batch size (batch_size * num_images_per_prompt) + by repeating each element along dimension 0. + + The input tensor must have batch size 1 or batch_size. The function will: + - If batch size is 1: repeat each element (batch_size * num_images_per_prompt) times + - If batch size equals batch_size: repeat each element num_images_per_prompt times + + Args: + input_name (str): Name of the input tensor (used for error messages) + input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size. + batch_size (int): The base batch size (number of prompts) + num_images_per_prompt (int, optional): Number of images to generate per prompt. Defaults to 1. + + Returns: + torch.Tensor: The repeated tensor with final batch size (batch_size * num_images_per_prompt) + + Raises: + ValueError: If input_tensor is not a torch.Tensor or has invalid batch size + + Examples: + tensor = torch.tensor([[1, 2, 3]]) # shape: [1, 3] repeated = repeat_tensor_to_batch_size("image", tensor, + batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape: + [4, 3] + + tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: [2, 3] repeated = repeat_tensor_to_batch_size("image", + tensor, batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]) + - shape: [4, 3] + """ + # make sure input is a tensor + if not isinstance(input_tensor, torch.Tensor): + raise ValueError(f"`{input_name}` must be a tensor") + + # make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts + if input_tensor.shape[0] == 1: + repeat_by = batch_size * num_images_per_prompt + elif input_tensor.shape[0] == batch_size: + repeat_by = num_images_per_prompt + else: + raise ValueError( + f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}" + ) + + # expand the tensor to match the batch_size * num_images_per_prompt + input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0) + + return input_tensor + + +def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor: int) -> tuple[int, int]: + """Calculate image dimensions from latent tensor dimensions. + + This function converts latent space dimensions to image space dimensions by multiplying the latent height and width + by the VAE scale factor. + + Args: + latents (torch.Tensor): The latent tensor. Must have 4 or 5 dimensions. + Expected shapes: [batch, channels, height, width] or [batch, channels, frames, height, width] + vae_scale_factor (int): The scale factor used by the VAE to compress images. + Typically 8 for most VAEs (image is 8x larger than latents in each dimension) + + Returns: + tuple[int, int]: The calculated image dimensions as (height, width) + + Raises: + ValueError: If latents tensor doesn't have 4 or 5 dimensions + + """ + # make sure the latents are not packed + if latents.ndim != 4 and latents.ndim != 5: + raise ValueError(f"unpacked latents must have 4 or 5 dimensions, but got {latents.ndim}") + + latent_height, latent_width = latents.shape[-2:] + + height = latent_height * vae_scale_factor + width = latent_width * vae_scale_factor + + return height, width + + +# auto_docstring +class QwenImageTextInputsStep(ModularPipelineBlocks): + """ + Text input processing step that standardizes text embeddings for the pipeline. + This step: + 1. Determines `batch_size` and `dtype` based on `prompt_embeds` + 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt) + + This block should be placed after all encoder steps to process the text embeddings before they are used in + subsequent pipeline steps. + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + summary_section = ( + "Text input processing step that standardizes text embeddings for the pipeline.\n" + "This step:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)" + ) + + # Placement guidance + placement_section = "\n\nThis block should be placed after all encoder steps to process the text embeddings before they are used in subsequent pipeline steps." + + return summary_section + placement_section + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("num_images_per_prompt"), + InputParam.template("prompt_embeds"), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds"), + InputParam.template("negative_prompt_embeds_mask"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam(name="batch_size", type_hint=int, description="The batch size of the prompt embeddings"), + OutputParam(name="dtype", type_hint=torch.dtype, description="The data type of the prompt embeddings"), + OutputParam.template("prompt_embeds", note="batch-expanded"), + OutputParam.template("prompt_embeds_mask", note="batch-expanded"), + OutputParam.template("negative_prompt_embeds", note="batch-expanded"), + OutputParam.template("negative_prompt_embeds_mask", note="batch-expanded"), + ] + + @staticmethod + def check_inputs( + prompt_embeds, + prompt_embeds_mask, + negative_prompt_embeds, + negative_prompt_embeds_mask, + ): + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError("`negative_prompt_embeds_mask` is required when `negative_prompt_embeds` is not None") + + if negative_prompt_embeds is None and negative_prompt_embeds_mask is not None: + raise ValueError("cannot pass `negative_prompt_embeds_mask` without `negative_prompt_embeds`") + + if prompt_embeds_mask.shape[0] != prompt_embeds.shape[0]: + raise ValueError("`prompt_embeds_mask` must have the same batch size as `prompt_embeds`") + + elif negative_prompt_embeds is not None and negative_prompt_embeds.shape[0] != prompt_embeds.shape[0]: + raise ValueError("`negative_prompt_embeds` must have the same batch size as `prompt_embeds`") + + elif ( + negative_prompt_embeds_mask is not None and negative_prompt_embeds_mask.shape[0] != prompt_embeds.shape[0] + ): + raise ValueError("`negative_prompt_embeds_mask` must have the same batch size as `prompt_embeds`") + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + self.check_inputs( + prompt_embeds=block_state.prompt_embeds, + prompt_embeds_mask=block_state.prompt_embeds_mask, + negative_prompt_embeds=block_state.negative_prompt_embeds, + negative_prompt_embeds_mask=block_state.negative_prompt_embeds_mask, + ) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + block_state.prompt_embeds_mask = block_state.prompt_embeds_mask.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds_mask = block_state.prompt_embeds_mask.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len + ) + + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len + ) + + self.set_block_state(state, block_state) + + return components, state + + +# auto_docstring +class QwenImageAdditionalInputsStep(ModularPipelineBlocks): + """ + Input processing step that: + 1. For image latent inputs: Updates height/width if None, patchifies, and expands batch size + 2. For additional batch inputs: Expands batch dimensions to match final batch size + + Configured inputs: + - Image latent inputs: ['image_latents'] + + This block should be placed after the encoder steps and the text input step. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and + batch-expanded) + """ + + model_name = "qwenimage" + + def __init__( + self, + image_latent_inputs: list[InputParam] | None = None, + additional_batch_inputs: list[InputParam] | None = None, + ): + # by default, process `image_latents` + if image_latent_inputs is None: + image_latent_inputs = [InputParam.template("image_latents")] + if additional_batch_inputs is None: + additional_batch_inputs = [] + + if not isinstance(image_latent_inputs, list): + raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}") + else: + for input_param in image_latent_inputs: + if not isinstance(input_param, InputParam): + raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}") + + if not isinstance(additional_batch_inputs, list): + raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}") + else: + for input_param in additional_batch_inputs: + if not isinstance(input_param, InputParam): + raise ValueError( + f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}" + ) + + self._image_latent_inputs = image_latent_inputs + self._additional_batch_inputs = additional_batch_inputs + super().__init__() + + @property + def description(self) -> str: + summary_section = ( + "Input processing step that:\n" + " 1. For image latent inputs: Updates height/width if None, patchifies, and expands batch size\n" + " 2. For additional batch inputs: Expands batch dimensions to match final batch size" + ) + + inputs_info = "" + if self._image_latent_inputs or self._additional_batch_inputs: + inputs_info = "\n\nConfigured inputs:" + if self._image_latent_inputs: + inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}" + if self._additional_batch_inputs: + inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}" + + placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." + + return summary_section + inputs_info + placement_section + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"), + ] + + @property + def inputs(self) -> list[InputParam]: + inputs = [ + InputParam.template("num_images_per_prompt"), + InputParam.template("batch_size"), + InputParam.template("height"), + InputParam.template("width"), + ] + # default is `image_latents` + inputs += self._image_latent_inputs + self._additional_batch_inputs + + return inputs + + @property + def intermediate_outputs(self) -> list[OutputParam]: + outputs = [ + OutputParam( + name="image_height", + type_hint=int, + description="The image height calculated from the image latents dimension", + ), + OutputParam( + name="image_width", + type_hint=int, + description="The image width calculated from the image latents dimension", + ), + ] + + # `height`/`width` are not new outputs, but they will be updated if any image latent inputs are provided + if len(self._image_latent_inputs) > 0: + outputs.append( + OutputParam(name="height", type_hint=int, description="if not provided, updated to image height") + ) + outputs.append( + OutputParam(name="width", type_hint=int, description="if not provided, updated to image width") + ) + + # image latent inputs are modified in place (patchified and batch-expanded) + for input_param in self._image_latent_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (patchified and batch-expanded)", + ) + ) + + # additional batch inputs (batch-expanded only) + for input_param in self._additional_batch_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (batch-expanded)", + ) + ) + + return outputs + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Process image latent inputs + for input_param in self._image_latent_inputs: + image_latent_input_name = input_param.name + image_latent_tensor = getattr(block_state, image_latent_input_name) + if image_latent_tensor is None: + continue + + # 1. Calculate height/width from latents and update if not provided + height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor) + block_state.height = block_state.height or height + block_state.width = block_state.width or width + + if not hasattr(block_state, "image_height"): + block_state.image_height = height + if not hasattr(block_state, "image_width"): + block_state.image_width = width + + # 2. Patchify + image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor) + + # 3. Expand batch size + image_latent_tensor = repeat_tensor_to_batch_size( + input_name=image_latent_input_name, + input_tensor=image_latent_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, image_latent_input_name, image_latent_tensor) + + # Process additional batch inputs (only batch expansion) + for input_param in self._additional_batch_inputs: + input_name = input_param.name + input_tensor = getattr(block_state, input_name) + if input_tensor is None: + continue + + input_tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=input_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, input_name, input_tensor) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks): + """ + Input processing step for Edit Plus that: + 1. For image latent inputs (list): Collects heights/widths, patchifies each, concatenates, expands batch + 2. For additional batch inputs: Expands batch dimensions to match final batch size + Height/width defaults to last image in the list. + + Configured inputs: + - Image latent inputs: ['image_latents'] + + This block should be placed after the encoder steps and the text input step. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + image_height (`list`): + The image heights calculated from the image latents dimension + image_width (`list`): + The image widths calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified, + concatenated, and batch-expanded) + """ + + model_name = "qwenimage-edit-plus" + + def __init__( + self, + image_latent_inputs: list[InputParam] | None = None, + additional_batch_inputs: list[InputParam] | None = None, + ): + if image_latent_inputs is None: + image_latent_inputs = [InputParam.template("image_latents")] + if additional_batch_inputs is None: + additional_batch_inputs = [] + + if not isinstance(image_latent_inputs, list): + raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}") + else: + for input_param in image_latent_inputs: + if not isinstance(input_param, InputParam): + raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}") + + if not isinstance(additional_batch_inputs, list): + raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}") + else: + for input_param in additional_batch_inputs: + if not isinstance(input_param, InputParam): + raise ValueError( + f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}" + ) + + self._image_latent_inputs = image_latent_inputs + self._additional_batch_inputs = additional_batch_inputs + super().__init__() + + @property + def description(self) -> str: + summary_section = ( + "Input processing step for Edit Plus that:\n" + " 1. For image latent inputs (list): Collects heights/widths, patchifies each, concatenates, expands batch\n" + " 2. For additional batch inputs: Expands batch dimensions to match final batch size\n" + " Height/width defaults to last image in the list." + ) + + inputs_info = "" + if self._image_latent_inputs or self._additional_batch_inputs: + inputs_info = "\n\nConfigured inputs:" + if self._image_latent_inputs: + inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}" + if self._additional_batch_inputs: + inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}" + + placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." + + return summary_section + inputs_info + placement_section + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"), + ] + + @property + def inputs(self) -> list[InputParam]: + inputs = [ + InputParam.template("num_images_per_prompt"), + InputParam.template("batch_size"), + InputParam.template("height"), + InputParam.template("width"), + ] + + # default is `image_latents` + inputs += self._image_latent_inputs + self._additional_batch_inputs + + return inputs + + @property + def intermediate_outputs(self) -> list[OutputParam]: + outputs = [ + OutputParam( + name="image_height", + type_hint=list[int], + description="The image heights calculated from the image latents dimension", + ), + OutputParam( + name="image_width", + type_hint=list[int], + description="The image widths calculated from the image latents dimension", + ), + ] + + # `height`/`width` are updated if any image latent inputs are provided + if len(self._image_latent_inputs) > 0: + outputs.append( + OutputParam(name="height", type_hint=int, description="if not provided, updated to image height") + ) + outputs.append( + OutputParam(name="width", type_hint=int, description="if not provided, updated to image width") + ) + + # image latent inputs are modified in place (patchified, concatenated, and batch-expanded) + for input_param in self._image_latent_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (patchified, concatenated, and batch-expanded)", + ) + ) + + # additional batch inputs (batch-expanded only) + for input_param in self._additional_batch_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (batch-expanded)", + ) + ) + + return outputs + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Process image latent inputs + for input_param in self._image_latent_inputs: + image_latent_input_name = input_param.name + image_latent_tensor = getattr(block_state, image_latent_input_name) + if image_latent_tensor is None: + continue + + is_list = isinstance(image_latent_tensor, list) + if not is_list: + image_latent_tensor = [image_latent_tensor] + + image_heights = [] + image_widths = [] + packed_image_latent_tensors = [] + + for i, img_latent_tensor in enumerate(image_latent_tensor): + # 1. Calculate height/width from latents + height, width = calculate_dimension_from_latents(img_latent_tensor, components.vae_scale_factor) + image_heights.append(height) + image_widths.append(width) + + # 2. Patchify + img_latent_tensor = components.pachifier.pack_latents(img_latent_tensor) + + # 3. Expand batch size + img_latent_tensor = repeat_tensor_to_batch_size( + input_name=f"{image_latent_input_name}[{i}]", + input_tensor=img_latent_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + packed_image_latent_tensors.append(img_latent_tensor) + + # Concatenate all packed latents along dim=1 + packed_image_latent_tensors = torch.cat(packed_image_latent_tensors, dim=1) + + # Output lists of heights/widths + block_state.image_height = image_heights + block_state.image_width = image_widths + + # Default height/width from last image + block_state.height = block_state.height or image_heights[-1] + block_state.width = block_state.width or image_widths[-1] + + setattr(block_state, image_latent_input_name, packed_image_latent_tensors) + + # Process additional batch inputs (only batch expansion) + for input_param in self._additional_batch_inputs: + input_name = input_param.name + input_tensor = getattr(block_state, input_name) + if input_tensor is None: + continue + + input_tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=input_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, input_name, input_tensor) + + self.set_block_state(state, block_state) + return components, state + + +# same as QwenImageAdditionalInputsStep, but with layered pachifier. + + +# auto_docstring +class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks): + """ + Input processing step for Layered that: + 1. For image latent inputs: Updates height/width if None, patchifies with layered pachifier, and expands batch + size + 2. For additional batch inputs: Expands batch dimensions to match final batch size + + Configured inputs: + - Image latent inputs: ['image_latents'] + + This block should be placed after the encoder steps and the text input step. + + Components: + pachifier (`QwenImageLayeredPachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified + with layered pachifier and batch-expanded) + """ + + model_name = "qwenimage-layered" + + def __init__( + self, + image_latent_inputs: list[InputParam] | None = None, + additional_batch_inputs: list[InputParam] | None = None, + ): + if image_latent_inputs is None: + image_latent_inputs = [InputParam.template("image_latents")] + if additional_batch_inputs is None: + additional_batch_inputs = [] + + if not isinstance(image_latent_inputs, list): + raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}") + else: + for input_param in image_latent_inputs: + if not isinstance(input_param, InputParam): + raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}") + + if not isinstance(additional_batch_inputs, list): + raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}") + else: + for input_param in additional_batch_inputs: + if not isinstance(input_param, InputParam): + raise ValueError( + f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}" + ) + + self._image_latent_inputs = image_latent_inputs + self._additional_batch_inputs = additional_batch_inputs + super().__init__() + + @property + def description(self) -> str: + summary_section = ( + "Input processing step for Layered that:\n" + " 1. For image latent inputs: Updates height/width if None, patchifies with layered pachifier, and expands batch size\n" + " 2. For additional batch inputs: Expands batch dimensions to match final batch size" + ) + + inputs_info = "" + if self._image_latent_inputs or self._additional_batch_inputs: + inputs_info = "\n\nConfigured inputs:" + if self._image_latent_inputs: + inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}" + if self._additional_batch_inputs: + inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}" + + placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." + + return summary_section + inputs_info + placement_section + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("pachifier", QwenImageLayeredPachifier, default_creation_method="from_config"), + ] + + @property + def inputs(self) -> list[InputParam]: + inputs = [ + InputParam.template("num_images_per_prompt"), + InputParam.template("batch_size"), + ] + # default is `image_latents` + + inputs += self._image_latent_inputs + self._additional_batch_inputs + + return inputs + + @property + def intermediate_outputs(self) -> list[OutputParam]: + outputs = [ + OutputParam( + name="image_height", + type_hint=int, + description="The image height calculated from the image latents dimension", + ), + OutputParam( + name="image_width", + type_hint=int, + description="The image width calculated from the image latents dimension", + ), + ] + + if len(self._image_latent_inputs) > 0: + outputs.append( + OutputParam(name="height", type_hint=int, description="if not provided, updated to image height") + ) + outputs.append( + OutputParam(name="width", type_hint=int, description="if not provided, updated to image width") + ) + + # Add outputs for image latent inputs (patchified with layered pachifier and batch-expanded) + for input_param in self._image_latent_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (patchified with layered pachifier and batch-expanded)", + ) + ) + + # Add outputs for additional batch inputs (batch-expanded only) + for input_param in self._additional_batch_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (batch-expanded)", + ) + ) + + return outputs + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Process image latent inputs + for input_param in self._image_latent_inputs: + image_latent_input_name = input_param.name + image_latent_tensor = getattr(block_state, image_latent_input_name) + if image_latent_tensor is None: + continue + + # 1. Calculate height/width from latents and update if not provided + # Layered latents are (B, layers, C, H, W) + height = image_latent_tensor.shape[3] * components.vae_scale_factor + width = image_latent_tensor.shape[4] * components.vae_scale_factor + block_state.height = height + block_state.width = width + + if not hasattr(block_state, "image_height"): + block_state.image_height = height + if not hasattr(block_state, "image_width"): + block_state.image_width = width + + # 2. Patchify with layered pachifier + image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor) + + # 3. Expand batch size + image_latent_tensor = repeat_tensor_to_batch_size( + input_name=image_latent_input_name, + input_tensor=image_latent_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, image_latent_input_name, image_latent_tensor) + + # Process additional batch inputs (only batch expansion) + for input_param in self._additional_batch_inputs: + input_name = input_param.name + input_tensor = getattr(block_state, input_name) + if input_tensor is None: + continue + + input_tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=input_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, input_name, input_tensor) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageControlNetInputsStep(ModularPipelineBlocks): + """ + prepare the `control_image_latents` for controlnet. Insert after all the other inputs steps. + + Inputs: + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + + Outputs: + control_image_latents (`Tensor`): + The control image latents (patchified and batch-expanded). + height (`int`): + if not provided, updated to control image height + width (`int`): + if not provided, updated to control image width + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "prepare the `control_image_latents` for controlnet. Insert after all the other inputs steps." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="control_image_latents", + required=True, + type_hint=torch.Tensor, + description="The control image latents to use for the denoising process. Can be generated in controlnet vae encoder step.", + ), + InputParam.template("batch_size"), + InputParam.template("num_images_per_prompt"), + InputParam.template("height"), + InputParam.template("width"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="control_image_latents", + type_hint=torch.Tensor, + description="The control image latents (patchified and batch-expanded).", + ), + OutputParam(name="height", type_hint=int, description="if not provided, updated to control image height"), + OutputParam(name="width", type_hint=int, description="if not provided, updated to control image width"), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if isinstance(components.controlnet, QwenImageMultiControlNetModel): + control_image_latents = [] + # loop through each control_image_latents + for i, control_image_latents_ in enumerate(block_state.control_image_latents): + # 1. update height/width if not provided + height, width = calculate_dimension_from_latents(control_image_latents_, components.vae_scale_factor) + block_state.height = block_state.height or height + block_state.width = block_state.width or width + + # 2. pack + control_image_latents_ = components.pachifier.pack_latents(control_image_latents_) + + # 3. repeat to match the batch size + control_image_latents_ = repeat_tensor_to_batch_size( + input_name=f"control_image_latents[{i}]", + input_tensor=control_image_latents_, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + control_image_latents.append(control_image_latents_) + + block_state.control_image_latents = control_image_latents + + else: + # 1. update height/width if not provided + height, width = calculate_dimension_from_latents( + block_state.control_image_latents, components.vae_scale_factor + ) + block_state.height = block_state.height or height + block_state.width = block_state.width or width + + # 2. pack + block_state.control_image_latents = components.pachifier.pack_latents(block_state.control_image_latents) + + # 3. repeat to match the batch size + block_state.control_image_latents = repeat_tensor_to_batch_size( + input_name="control_image_latents", + input_tensor=block_state.control_image_latents, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + block_state.control_image_latents = block_state.control_image_latents + + self.set_block_state(state, block_state) + + return components, state diff --git a/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py b/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py new file mode 100644 index 0000000000000000000000000000000000000000..bf87028b2f90841c9ba17357879b69b74cf8b150 --- /dev/null +++ b/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py @@ -0,0 +1,1224 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InputParam, InsertableDict, OutputParam +from .before_denoise import ( + QwenImageControlNetBeforeDenoiserStep, + QwenImageCreateMaskLatentsStep, + QwenImagePrepareLatentsStep, + QwenImagePrepareLatentsWithStrengthStep, + QwenImageRoPEInputsStep, + QwenImageSetTimestepsStep, + QwenImageSetTimestepsWithStrengthStep, +) +from .decoders import ( + QwenImageAfterDenoiseStep, + QwenImageDecoderStep, + QwenImageInpaintProcessImagesOutputStep, + QwenImageProcessImagesOutputStep, +) +from .denoise import ( + QwenImageControlNetDenoiseStep, + QwenImageDenoiseStep, + QwenImageInpaintControlNetDenoiseStep, + QwenImageInpaintDenoiseStep, +) +from .encoders import ( + QwenImageControlNetVaeEncoderStep, + QwenImageInpaintProcessImagesInputStep, + QwenImageProcessImagesInputStep, + QwenImageTextEncoderStep, + QwenImageVaeEncoderStep, +) +from .inputs import ( + QwenImageAdditionalInputsStep, + QwenImageControlNetInputsStep, + QwenImageTextInputsStep, +) + + +logger = logging.get_logger(__name__) + + +# ==================== +# 1. TEXT ENCODER +# ==================== + + +# auto_docstring +class QwenImageAutoTextEncoderStep(AutoPipelineBlocks): + """ + Text encoder step that encodes the text prompt into a text embedding. This is an auto pipeline block. + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`): + The tokenizer to use guider (`ClassifierFreeGuidance`) + + Inputs: + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + + Outputs: + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + + model_name = "qwenimage" + block_classes = [QwenImageTextEncoderStep()] + block_names = ["text_encoder"] + block_trigger_inputs = ["prompt"] + + @property + def description(self) -> str: + return "Text encoder step that encodes the text prompt into a text embedding. This is an auto pipeline block." + " - `QwenImageTextEncoderStep` (text_encoder) is used when `prompt` is provided." + " - if `prompt` is not provided, step will be skipped." + + +# ==================== +# 2. VAE ENCODER +# ==================== + + +# auto_docstring +class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks): + """ + This step is used for processing image and mask inputs for inpainting tasks. It: + - Resizes the image to the target size, based on `height` and `width`. + - Processes and updates `image` and `mask_image`. + - Creates `image_latents`. + + Components: + image_mask_processor (`InpaintProcessor`) vae (`AutoencoderKLQwenImage`) + + Inputs: + mask_image (`Image`): + Mask image for inpainting. + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + processed_image (`Tensor`): + The processed image + processed_mask_image (`Tensor`): + The processed mask image + mask_overlay_kwargs (`dict`): + The kwargs for the postprocess step to apply the mask overlay + image_latents (`Tensor`): + The latent representation of the input image. + """ + + model_name = "qwenimage" + block_classes = [QwenImageInpaintProcessImagesInputStep(), QwenImageVaeEncoderStep()] + block_names = ["preprocess", "encode"] + + @property + def description(self) -> str: + return ( + "This step is used for processing image and mask inputs for inpainting tasks. It:\n" + " - Resizes the image to the target size, based on `height` and `width`.\n" + " - Processes and updates `image` and `mask_image`.\n" + " - Creates `image_latents`." + ) + + +# auto_docstring +class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks): + """ + Vae encoder step that preprocess andencode the image inputs into their latent representations. + + Components: + image_processor (`VaeImageProcessor`) vae (`AutoencoderKLQwenImage`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + processed_image (`Tensor`): + The processed image + image_latents (`Tensor`): + The latent representation of the input image. + """ + + model_name = "qwenimage" + + block_classes = [QwenImageProcessImagesInputStep(), QwenImageVaeEncoderStep()] + block_names = ["preprocess", "encode"] + + @property + def description(self) -> str: + return "Vae encoder step that preprocess andencode the image inputs into their latent representations." + + +class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks): + block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep] + block_names = ["inpaint", "img2img"] + block_trigger_inputs = ["mask_image", "image"] + + @property + def description(self): + return ( + "Vae encoder step that encode the image inputs into their latent representations.\n" + + "This is an auto pipeline block.\n" + + " - `QwenImageInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n" + + " - `QwenImageImg2ImgVaeEncoderStep` (img2img) is used when `image` is provided.\n" + + " - if `mask_image` or `image` is not provided, step will be skipped." + ) + + +# optional controlnet vae encoder +# auto_docstring +class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks): + """ + Vae encoder step that encode the image inputs into their latent representations. + This is an auto pipeline block. + - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided. + - if `control_image` is not provided, step will be skipped. + + Components: + vae (`AutoencoderKLQwenImage`) controlnet (`QwenImageControlNetModel`) control_image_processor + (`VaeImageProcessor`) + + Inputs: + control_image (`Image`, *optional*): + Control image for ControlNet conditioning. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + control_image_latents (`Tensor`): + The latents representing the control image + """ + + block_classes = [QwenImageControlNetVaeEncoderStep] + block_names = ["controlnet"] + block_trigger_inputs = ["control_image"] + + @property + def description(self): + return ( + "Vae encoder step that encode the image inputs into their latent representations.\n" + + "This is an auto pipeline block.\n" + + " - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided.\n" + + " - if `control_image` is not provided, step will be skipped." + ) + + +# ==================== +# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise) +# ==================== + + +# assemble input steps +# auto_docstring +class QwenImageImg2ImgInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the img2img denoising step. It: + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and + batch-expanded) + """ + + model_name = "qwenimage" + block_classes = [QwenImageTextInputsStep(), QwenImageAdditionalInputsStep()] + block_names = ["text_inputs", "additional_inputs"] + + @property + def description(self): + return "Input step that prepares the inputs for the img2img denoising step. It:\n" + " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n" + " - update height/width based `image_latents`, patchify `image_latents`." + + +# auto_docstring +class QwenImageInpaintInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the inpainting denoising step. It: + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and + batch-expanded) + processed_mask_image (`Tensor`): + The processed mask image (batch-expanded) + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageTextInputsStep(), + QwenImageAdditionalInputsStep( + additional_batch_inputs=[ + InputParam(name="processed_mask_image", type_hint=torch.Tensor, description="The processed mask image") + ] + ), + ] + block_names = ["text_inputs", "additional_inputs"] + + @property + def description(self): + return "Input step that prepares the inputs for the inpainting denoising step. It:\n" + " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents` and `processed_mask_image`).\n" + " - update height/width based `image_latents`, patchify `image_latents`." + + +# assemble prepare latents steps +# auto_docstring +class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks): + """ + This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It: + - Add noise to the image latents to create the latents input for the denoiser. + - Create the pachified latents `mask` based on the processedmask image. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) pachifier (`QwenImagePachifier`) + + Inputs: + latents (`Tensor`): + The initial random noised, can be generated in prepare latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (Can be + generated from vae encoder and updated in input step.) + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + processed_mask_image (`Tensor`): + The processed mask to use for the inpainting process. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. + + Outputs: + initial_noise (`Tensor`): + The initial random noised used for inpainting denoising. + latents (`Tensor`): + The scaled noisy latents to use for inpainting/image-to-image denoising. + mask (`Tensor`): + The mask to use for the inpainting process. + """ + + model_name = "qwenimage" + block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()] + block_names = ["add_noise_to_latents", "create_mask_latents"] + + @property + def description(self) -> str: + return ( + "This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It:\n" + " - Add noise to the image latents to create the latents input for the denoiser.\n" + " - Create the pachified latents `mask` based on the processedmask image.\n" + ) + + +# assemble denoising steps + + +# Qwen Image (text2image) +# auto_docstring +class QwenImageCoreDenoiseStep(SequentialPipelineBlocks): + """ + step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs + (timesteps, latents, rope inputs etc.). + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageTextInputsStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsStep(), + QwenImageRoPEInputsStep(), + QwenImageDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "prepare_latents", + "set_timesteps", + "prepare_rope_inputs", + "denoise", + "after_denoise", + ] + + @property + def description(self): + return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Qwen Image (inpainting) +# auto_docstring +class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint + task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageInpaintInputStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsWithStrengthStep(), + QwenImageInpaintPrepareLatentsStep(), + QwenImageRoPEInputsStep(), + QwenImageInpaintDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "prepare_latents", + "set_timesteps", + "prepare_inpaint_latents", + "prepare_rope_inputs", + "denoise", + "after_denoise", + ] + + @property + def description(self): + return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Qwen Image (image2image) +# auto_docstring +class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img + task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageImg2ImgInputStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsWithStrengthStep(), + QwenImagePrepareLatentsWithStrengthStep(), + QwenImageRoPEInputsStep(), + QwenImageDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "prepare_latents", + "set_timesteps", + "prepare_img2img_latents", + "prepare_rope_inputs", + "denoise", + "after_denoise", + ] + + @property + def description(self): + return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Qwen Image (text2image) with controlnet +# auto_docstring +class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks): + """ + step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs + (timesteps, latents, rope inputs etc.). + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) controlnet + (`QwenImageControlNetModel`) guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageTextInputsStep(), + QwenImageControlNetInputsStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsStep(), + QwenImageRoPEInputsStep(), + QwenImageControlNetBeforeDenoiserStep(), + QwenImageControlNetDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "controlnet_input", + "prepare_latents", + "set_timesteps", + "prepare_rope_inputs", + "controlnet_before_denoise", + "controlnet_denoise", + "after_denoise", + ] + + @property + def description(self): + return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Qwen Image (inpainting) with controlnet +# auto_docstring +class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint + task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) controlnet + (`QwenImageControlNetModel`) guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageInpaintInputStep(), + QwenImageControlNetInputsStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsWithStrengthStep(), + QwenImageInpaintPrepareLatentsStep(), + QwenImageRoPEInputsStep(), + QwenImageControlNetBeforeDenoiserStep(), + QwenImageInpaintControlNetDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "controlnet_input", + "prepare_latents", + "set_timesteps", + "prepare_inpaint_latents", + "prepare_rope_inputs", + "controlnet_before_denoise", + "controlnet_denoise", + "after_denoise", + ] + + @property + def description(self): + return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Qwen Image (image2image) with controlnet +# auto_docstring +class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img + task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) controlnet + (`QwenImageControlNetModel`) guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageImg2ImgInputStep(), + QwenImageControlNetInputsStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsWithStrengthStep(), + QwenImagePrepareLatentsWithStrengthStep(), + QwenImageRoPEInputsStep(), + QwenImageControlNetBeforeDenoiserStep(), + QwenImageControlNetDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "controlnet_input", + "prepare_latents", + "set_timesteps", + "prepare_img2img_latents", + "prepare_rope_inputs", + "controlnet_before_denoise", + "controlnet_denoise", + "after_denoise", + ] + + @property + def description(self): + return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Auto denoise step for QwenImage +class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks): + block_classes = [ + QwenImageCoreDenoiseStep, + QwenImageInpaintCoreDenoiseStep, + QwenImageImg2ImgCoreDenoiseStep, + QwenImageControlNetCoreDenoiseStep, + QwenImageControlNetInpaintCoreDenoiseStep, + QwenImageControlNetImg2ImgCoreDenoiseStep, + ] + block_names = [ + "text2image", + "inpaint", + "img2img", + "controlnet_text2image", + "controlnet_inpaint", + "controlnet_img2img", + ] + block_trigger_inputs = ["control_image_latents", "processed_mask_image", "image_latents"] + default_block_name = "text2image" + + def select_block(self, control_image_latents=None, processed_mask_image=None, image_latents=None): + if control_image_latents is not None: + if processed_mask_image is not None: + return "controlnet_inpaint" + elif image_latents is not None: + return "controlnet_img2img" + else: + return "controlnet_text2image" + else: + if processed_mask_image is not None: + return "inpaint" + elif image_latents is not None: + return "img2img" + else: + return "text2image" + + @property + def description(self): + return ( + "Core step that performs the denoising process. \n" + + " - `QwenImageCoreDenoiseStep` (text2image) for text2image tasks.\n" + + " - `QwenImageInpaintCoreDenoiseStep` (inpaint) for inpaint tasks.\n" + + " - `QwenImageImg2ImgCoreDenoiseStep` (img2img) for img2img tasks.\n" + + " - `QwenImageControlNetCoreDenoiseStep` (controlnet_text2image) for text2image tasks with controlnet.\n" + + " - `QwenImageControlNetInpaintCoreDenoiseStep` (controlnet_inpaint) for inpaint tasks with controlnet.\n" + + " - `QwenImageControlNetImg2ImgCoreDenoiseStep` (controlnet_img2img) for img2img tasks with controlnet.\n" + + "This step support text-to-image, image-to-image, inpainting, and controlnet tasks for QwenImage:\n" + + " - for image-to-image generation, you need to provide `image_latents`\n" + + " - for inpainting, you need to provide `processed_mask_image` and `image_latents`\n" + + " - to run the controlnet workflow, you need to provide `control_image_latents`\n" + + " - for text-to-image generation, all you need to provide is prompt embeddings" + ) + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# ==================== +# 4. DECODE +# ==================== + + +# standard decode step works for most tasks except for inpaint +# auto_docstring +class QwenImageDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocess the generated image. + + Components: + vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`list`): + Generated images. (tensor output of the vae decoder.) + """ + + model_name = "qwenimage" + block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()] + block_names = ["decode", "postprocess"] + + @property + def description(self): + return "Decode step that decodes the latents to images and postprocess the generated image." + + +# Inpaint decode step +# auto_docstring +class QwenImageInpaintDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocess the generated image, optional apply the mask + overally to the original image. + + Components: + vae (`AutoencoderKLQwenImage`) image_mask_processor (`InpaintProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + mask_overlay_kwargs (`dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. generated in + InpaintProcessImagesInputStep. + + Outputs: + images (`list`): + Generated images. (tensor output of the vae decoder.) + """ + + model_name = "qwenimage" + block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()] + block_names = ["decode", "postprocess"] + + @property + def description(self): + return "Decode step that decodes the latents to images and postprocess the generated image, optional apply the mask overally to the original image." + + +# Auto decode step for QwenImage +class QwenImageAutoDecodeStep(AutoPipelineBlocks): + block_classes = [QwenImageInpaintDecodeStep, QwenImageDecodeStep] + block_names = ["inpaint_decode", "decode"] + block_trigger_inputs = ["mask", None] + + @property + def description(self): + return ( + "Decode step that decode the latents into images. \n" + " This is an auto pipeline block that works for inpaint/text2image/img2img tasks, for both QwenImage and QwenImage-Edit.\n" + + " - `QwenImageInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n" + + " - `QwenImageDecodeStep` (text2image/img2img) is used when `mask` is not provided.\n" + ) + + +# ==================== +# 5. AUTO BLOCKS & PRESETS +# ==================== +AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", QwenImageAutoTextEncoderStep()), + ("vae_encoder", QwenImageAutoVaeEncoderStep()), + ("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()), + ("denoise", QwenImageAutoCoreDenoiseStep()), + ("decode", QwenImageAutoDecodeStep()), + ] +) + + +# auto_docstring +class QwenImageAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage. + + Supported workflows: + - `text2image`: requires `prompt` + - `image2image`: requires `prompt`, `image` + - `inpainting`: requires `prompt`, `mask_image`, `image` + - `controlnet_text2image`: requires `prompt`, `control_image` + - `controlnet_image2image`: requires `prompt`, `image`, `control_image` + - `controlnet_inpainting`: requires `prompt`, `mask_image`, `image`, `control_image` + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`): + The tokenizer to use guider (`ClassifierFreeGuidance`) image_mask_processor (`InpaintProcessor`) vae + (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) controlnet (`QwenImageControlNetModel`) + control_image_processor (`VaeImageProcessor`) pachifier (`QwenImagePachifier`) scheduler + (`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + mask_image (`Image`, *optional*): + Mask image for inpainting. + image (`Image | list`, *optional*): + Reference image(s) for denoising. Can be a single image or list of images. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + control_image (`Image`, *optional*): + Control image for ControlNet conditioning. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + latents (`Tensor`): + Pre-generated noisy latents for image generation. + num_inference_steps (`int`): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + control_image_latents (`Tensor`, *optional*): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + mask_overlay_kwargs (`dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. generated in + InpaintProcessImagesInputStep. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "qwenimage" + + block_classes = AUTO_BLOCKS.values() + block_names = AUTO_BLOCKS.keys() + + # Workflow map defines the trigger conditions for each workflow. + # How to define: + # - Only include required inputs and trigger inputs (inputs that determine which blocks run) + # - currently, only supports `True` means the workflow triggers when the input is not None + + _workflow_map = { + "text2image": {"prompt": True}, + "image2image": {"prompt": True, "image": True}, + "inpainting": {"prompt": True, "mask_image": True, "image": True}, + "controlnet_text2image": {"prompt": True, "control_image": True}, + "controlnet_image2image": {"prompt": True, "image": True, "control_image": True}, + "controlnet_inpainting": {"prompt": True, "mask_image": True, "image": True, "control_image": True}, + } + + @property + def description(self): + return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage." + + @property + def outputs(self): + return [OutputParam.template("images")] diff --git a/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py b/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py new file mode 100644 index 0000000000000000000000000000000000000000..37b80b69ec7ec5f8d027f59bfabafbda07772dc4 --- /dev/null +++ b/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py @@ -0,0 +1,796 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InputParam, InsertableDict, OutputParam +from .before_denoise import ( + QwenImageCreateMaskLatentsStep, + QwenImageEditRoPEInputsStep, + QwenImagePrepareLatentsStep, + QwenImagePrepareLatentsWithStrengthStep, + QwenImageSetTimestepsStep, + QwenImageSetTimestepsWithStrengthStep, +) +from .decoders import ( + QwenImageAfterDenoiseStep, + QwenImageDecoderStep, + QwenImageInpaintProcessImagesOutputStep, + QwenImageProcessImagesOutputStep, +) +from .denoise import ( + QwenImageEditDenoiseStep, + QwenImageEditInpaintDenoiseStep, +) +from .encoders import ( + QwenImageEditInpaintProcessImagesInputStep, + QwenImageEditProcessImagesInputStep, + QwenImageEditResizeStep, + QwenImageEditTextEncoderStep, + QwenImageVaeEncoderStep, +) +from .inputs import ( + QwenImageAdditionalInputsStep, + QwenImageTextInputsStep, +) + + +logger = logging.get_logger(__name__) + + +# ==================== +# 1. TEXT ENCODER +# ==================== + + +# auto_docstring +class QwenImageEditVLEncoderStep(SequentialPipelineBlocks): + """ + QwenImage-Edit VL encoder step that encode the image and text prompts together. + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + + Outputs: + resized_image (`list`): + The resized images + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + + model_name = "qwenimage-edit" + block_classes = [ + QwenImageEditResizeStep(), + QwenImageEditTextEncoderStep(), + ] + block_names = ["resize", "encode"] + + @property + def description(self) -> str: + return "QwenImage-Edit VL encoder step that encode the image and text prompts together." + + +# ==================== +# 2. VAE ENCODER +# ==================== + + +# Edit VAE encoder +# auto_docstring +class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks): + """ + Vae encoder step that encode the image inputs into their latent representations. + + Components: + image_resize_processor (`VaeImageProcessor`) image_processor (`VaeImageProcessor`) vae + (`AutoencoderKLQwenImage`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + resized_image (`list`): + The resized images + processed_image (`Tensor`): + The processed image + image_latents (`Tensor`): + The latent representation of the input image. + """ + + model_name = "qwenimage-edit" + block_classes = [ + QwenImageEditResizeStep(), + QwenImageEditProcessImagesInputStep(), + QwenImageVaeEncoderStep(), + ] + block_names = ["resize", "preprocess", "encode"] + + @property + def description(self) -> str: + return "Vae encoder step that encode the image inputs into their latent representations." + + +# Edit Inpaint VAE encoder +# auto_docstring +class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks): + """ + This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It: + - resize the image for target area (1024 * 1024) while maintaining the aspect ratio. + - process the resized image and mask image. + - create image latents. + + Components: + image_resize_processor (`VaeImageProcessor`) image_mask_processor (`InpaintProcessor`) vae + (`AutoencoderKLQwenImage`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + mask_image (`Image`): + Mask image for inpainting. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + resized_image (`list`): + The resized images + processed_image (`Tensor`): + The processed image + processed_mask_image (`Tensor`): + The processed mask image + mask_overlay_kwargs (`dict`): + The kwargs for the postprocess step to apply the mask overlay + image_latents (`Tensor`): + The latent representation of the input image. + """ + + model_name = "qwenimage-edit" + block_classes = [ + QwenImageEditResizeStep(), + QwenImageEditInpaintProcessImagesInputStep(), + QwenImageVaeEncoderStep(), + ] + block_names = ["resize", "preprocess", "encode"] + + @property + def description(self) -> str: + return ( + "This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It:\n" + " - resize the image for target area (1024 * 1024) while maintaining the aspect ratio.\n" + " - process the resized image and mask image.\n" + " - create image latents." + ) + + +# Auto VAE encoder +class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks): + block_classes = [QwenImageEditInpaintVaeEncoderStep, QwenImageEditVaeEncoderStep] + block_names = ["edit_inpaint", "edit"] + block_trigger_inputs = ["mask_image", "image"] + + @property + def description(self): + return ( + "Vae encoder step that encode the image inputs into their latent representations.\n" + "This is an auto pipeline block.\n" + " - `QwenImageEditInpaintVaeEncoderStep` (edit_inpaint) is used when `mask_image` is provided.\n" + " - `QwenImageEditVaeEncoderStep` (edit) is used when `image` is provided.\n" + " - if `mask_image` or `image` is not provided, step will be skipped." + ) + + +# ==================== +# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise) +# ==================== + + +# assemble input steps +# auto_docstring +class QwenImageEditInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the edit denoising step. It: + - make sure the text embeddings have consistent batch size as well as the additional inputs. + - update height/width based `image_latents`, patchify `image_latents`. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and + batch-expanded) + """ + + model_name = "qwenimage-edit" + block_classes = [ + QwenImageTextInputsStep(), + QwenImageAdditionalInputsStep(), + ] + block_names = ["text_inputs", "additional_inputs"] + + @property + def description(self): + return ( + "Input step that prepares the inputs for the edit denoising step. It:\n" + " - make sure the text embeddings have consistent batch size as well as the additional inputs.\n" + " - update height/width based `image_latents`, patchify `image_latents`." + ) + + +# auto_docstring +class QwenImageEditInpaintInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the edit inpaint denoising step. It: + - make sure the text embeddings have consistent batch size as well as the additional inputs. + - update height/width based `image_latents`, patchify `image_latents`. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and + batch-expanded) + processed_mask_image (`Tensor`): + The processed mask image (batch-expanded) + """ + + model_name = "qwenimage-edit" + block_classes = [ + QwenImageTextInputsStep(), + QwenImageAdditionalInputsStep( + additional_batch_inputs=[ + InputParam(name="processed_mask_image", type_hint=torch.Tensor, description="The processed mask image") + ] + ), + ] + block_names = ["text_inputs", "additional_inputs"] + + @property + def description(self): + return ( + "Input step that prepares the inputs for the edit inpaint denoising step. It:\n" + " - make sure the text embeddings have consistent batch size as well as the additional inputs.\n" + " - update height/width based `image_latents`, patchify `image_latents`." + ) + + +# assemble prepare latents steps +# auto_docstring +class QwenImageEditInpaintPrepareLatentsStep(SequentialPipelineBlocks): + """ + This step prepares the latents/image_latents and mask inputs for the edit inpainting denoising step. It: + - Add noise to the image latents to create the latents input for the denoiser. + - Create the patchified latents `mask` based on the processed mask image. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) pachifier (`QwenImagePachifier`) + + Inputs: + latents (`Tensor`): + The initial random noised, can be generated in prepare latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (Can be + generated from vae encoder and updated in input step.) + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + processed_mask_image (`Tensor`): + The processed mask to use for the inpainting process. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. + + Outputs: + initial_noise (`Tensor`): + The initial random noised used for inpainting denoising. + latents (`Tensor`): + The scaled noisy latents to use for inpainting/image-to-image denoising. + mask (`Tensor`): + The mask to use for the inpainting process. + """ + + model_name = "qwenimage-edit" + block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()] + block_names = ["add_noise_to_latents", "create_mask_latents"] + + @property + def description(self) -> str: + return ( + "This step prepares the latents/image_latents and mask inputs for the edit inpainting denoising step. It:\n" + " - Add noise to the image latents to create the latents input for the denoiser.\n" + " - Create the patchified latents `mask` based on the processed mask image.\n" + ) + + +# Qwen Image Edit (image2image) core denoise step +# auto_docstring +class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoising workflow for QwenImage-Edit edit (img2img) task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage-edit" + block_classes = [ + QwenImageEditInputStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsStep(), + QwenImageEditRoPEInputsStep(), + QwenImageEditDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "prepare_latents", + "set_timesteps", + "prepare_rope_inputs", + "denoise", + "after_denoise", + ] + + @property + def description(self): + return "Core denoising workflow for QwenImage-Edit edit (img2img) task." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Qwen Image Edit (inpainting) core denoise step +# auto_docstring +class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoising workflow for QwenImage-Edit edit inpaint task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage-edit" + block_classes = [ + QwenImageEditInpaintInputStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsWithStrengthStep(), + QwenImageEditInpaintPrepareLatentsStep(), + QwenImageEditRoPEInputsStep(), + QwenImageEditInpaintDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "prepare_latents", + "set_timesteps", + "prepare_inpaint_latents", + "prepare_rope_inputs", + "denoise", + "after_denoise", + ] + + @property + def description(self): + return "Core denoising workflow for QwenImage-Edit edit inpaint task." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Auto core denoise step for QwenImage Edit +class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks): + model_name = "qwenimage-edit" + block_classes = [ + QwenImageEditInpaintCoreDenoiseStep, + QwenImageEditCoreDenoiseStep, + ] + block_names = ["edit_inpaint", "edit"] + block_trigger_inputs = ["processed_mask_image", "image_latents"] + default_block_name = "edit" + + def select_block(self, processed_mask_image=None, image_latents=None) -> str | None: + if processed_mask_image is not None: + return "edit_inpaint" + elif image_latents is not None: + return "edit" + return None + + @property + def description(self): + return ( + "Auto core denoising step that selects the appropriate workflow based on inputs.\n" + " - `QwenImageEditInpaintCoreDenoiseStep` when `processed_mask_image` is provided\n" + " - `QwenImageEditCoreDenoiseStep` when `image_latents` is provided\n" + "Supports edit (img2img) and edit inpainting tasks for QwenImage-Edit." + ) + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# ==================== +# 4. DECODE +# ==================== + + +# Decode step (standard) +# auto_docstring +class QwenImageEditDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocess the generated image. + + Components: + vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`list`): + Generated images. (tensor output of the vae decoder.) + """ + + model_name = "qwenimage-edit" + block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()] + block_names = ["decode", "postprocess"] + + @property + def description(self): + return "Decode step that decodes the latents to images and postprocess the generated image." + + +# Inpaint decode step +# auto_docstring +class QwenImageEditInpaintDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocess the generated image, optionally apply the mask + overlay to the original image. + + Components: + vae (`AutoencoderKLQwenImage`) image_mask_processor (`InpaintProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + mask_overlay_kwargs (`dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. generated in + InpaintProcessImagesInputStep. + + Outputs: + images (`list`): + Generated images. (tensor output of the vae decoder.) + """ + + model_name = "qwenimage-edit" + block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()] + block_names = ["decode", "postprocess"] + + @property + def description(self): + return "Decode step that decodes the latents to images and postprocess the generated image, optionally apply the mask overlay to the original image." + + +# Auto decode step +class QwenImageEditAutoDecodeStep(AutoPipelineBlocks): + block_classes = [QwenImageEditInpaintDecodeStep, QwenImageEditDecodeStep] + block_names = ["inpaint_decode", "decode"] + block_trigger_inputs = ["mask", None] + + @property + def description(self): + return ( + "Decode step that decode the latents into images.\n" + "This is an auto pipeline block.\n" + " - `QwenImageEditInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n" + " - `QwenImageEditDecodeStep` (edit) is used when `mask` is not provided.\n" + ) + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# ==================== +# 5. AUTO BLOCKS & PRESETS +# ==================== + +EDIT_AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", QwenImageEditVLEncoderStep()), + ("vae_encoder", QwenImageEditAutoVaeEncoderStep()), + ("denoise", QwenImageEditAutoCoreDenoiseStep()), + ("decode", QwenImageEditAutoDecodeStep()), + ] +) + + +# auto_docstring +class QwenImageEditAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit. + - for edit (img2img) generation, you need to provide `image` + - for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide + `padding_mask_crop` + + + Supported workflows: + - `image_conditioned`: requires `prompt`, `image` + - `image_conditioned_inpainting`: requires `prompt`, `mask_image`, `image` + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) image_mask_processor (`InpaintProcessor`) vae + (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) pachifier (`QwenImagePachifier`) scheduler + (`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + mask_image (`Image`, *optional*): + Mask image for inpainting. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + latents (`Tensor`): + Pre-generated noisy latents for image generation. + num_inference_steps (`int`): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + mask_overlay_kwargs (`dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. generated in + InpaintProcessImagesInputStep. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "qwenimage-edit" + block_classes = EDIT_AUTO_BLOCKS.values() + block_names = EDIT_AUTO_BLOCKS.keys() + _workflow_map = { + "image_conditioned": {"prompt": True, "image": True}, + "image_conditioned_inpainting": {"prompt": True, "mask_image": True, "image": True}, + } + + @property + def description(self): + return ( + "Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit.\n" + "- for edit (img2img) generation, you need to provide `image`\n" + "- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`\n" + ) + + @property + def outputs(self): + return [OutputParam.template("images")] diff --git a/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py b/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py new file mode 100644 index 0000000000000000000000000000000000000000..4a1f418d7b4508c2b5702c7d04c3767174287ab4 --- /dev/null +++ b/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py @@ -0,0 +1,407 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict, OutputParam +from .before_denoise import ( + QwenImageEditPlusRoPEInputsStep, + QwenImagePrepareLatentsStep, + QwenImageSetTimestepsStep, +) +from .decoders import ( + QwenImageAfterDenoiseStep, + QwenImageDecoderStep, + QwenImageProcessImagesOutputStep, +) +from .denoise import ( + QwenImageEditDenoiseStep, +) +from .encoders import ( + QwenImageEditPlusProcessImagesInputStep, + QwenImageEditPlusResizeStep, + QwenImageEditPlusTextEncoderStep, + QwenImageVaeEncoderStep, +) +from .inputs import ( + QwenImageEditPlusAdditionalInputsStep, + QwenImageTextInputsStep, +) + + +logger = logging.get_logger(__name__) + + +# ==================== +# 1. TEXT ENCODER +# ==================== + + +# auto_docstring +class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks): + """ + QwenImage-Edit Plus VL encoder step that encodes the image and text prompts together. + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + + Outputs: + resized_image (`list`): + Images resized to 1024x1024 target area for VAE encoding + resized_cond_image (`list`): + Images resized to 384x384 target area for VL text encoding + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + + model_name = "qwenimage-edit-plus" + block_classes = [ + QwenImageEditPlusResizeStep(), + QwenImageEditPlusTextEncoderStep(), + ] + block_names = ["resize", "encode"] + + @property + def description(self) -> str: + return "QwenImage-Edit Plus VL encoder step that encodes the image and text prompts together." + + +# ==================== +# 2. VAE ENCODER +# ==================== + + +# auto_docstring +class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks): + """ + VAE encoder step that encodes image inputs into latent representations. + Each image is resized independently based on its own aspect ratio to 1024x1024 target area. + + Components: + image_resize_processor (`VaeImageProcessor`) image_processor (`VaeImageProcessor`) vae + (`AutoencoderKLQwenImage`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + resized_image (`list`): + Images resized to 1024x1024 target area for VAE encoding + resized_cond_image (`list`): + Images resized to 384x384 target area for VL text encoding + processed_image (`Tensor`): + The processed image + image_latents (`Tensor`): + The latent representation of the input image. + """ + + model_name = "qwenimage-edit-plus" + block_classes = [ + QwenImageEditPlusResizeStep(), + QwenImageEditPlusProcessImagesInputStep(), + QwenImageVaeEncoderStep(), + ] + block_names = ["resize", "preprocess", "encode"] + + @property + def description(self) -> str: + return ( + "VAE encoder step that encodes image inputs into latent representations.\n" + "Each image is resized independently based on its own aspect ratio to 1024x1024 target area." + ) + + +# ==================== +# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise) +# ==================== + + +# assemble input steps +# auto_docstring +class QwenImageEditPlusInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the Edit Plus denoising step. It: + - Standardizes text embeddings batch size. + - Processes list of image latents: patchifies, concatenates along dim=1, expands batch. + - Outputs lists of image_height/image_width for RoPE calculation. + - Defaults height/width from last image in the list. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`list`): + The image heights calculated from the image latents dimension + image_width (`list`): + The image widths calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified, + concatenated, and batch-expanded) + """ + + model_name = "qwenimage-edit-plus" + block_classes = [ + QwenImageTextInputsStep(), + QwenImageEditPlusAdditionalInputsStep(), + ] + block_names = ["text_inputs", "additional_inputs"] + + @property + def description(self): + return ( + "Input step that prepares the inputs for the Edit Plus denoising step. It:\n" + " - Standardizes text embeddings batch size.\n" + " - Processes list of image latents: patchifies, concatenates along dim=1, expands batch.\n" + " - Outputs lists of image_height/image_width for RoPE calculation.\n" + " - Defaults height/width from last image in the list." + ) + + +# Qwen Image Edit Plus (image2image) core denoise step +# auto_docstring +class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoising workflow for QwenImage-Edit Plus edit (img2img) task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage-edit-plus" + block_classes = [ + QwenImageEditPlusInputStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsStep(), + QwenImageEditPlusRoPEInputsStep(), + QwenImageEditDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "prepare_latents", + "set_timesteps", + "prepare_rope_inputs", + "denoise", + "after_denoise", + ] + + @property + def description(self): + return "Core denoising workflow for QwenImage-Edit Plus edit (img2img) task." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# ==================== +# 4. DECODE +# ==================== + + +# auto_docstring +class QwenImageEditPlusDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocesses the generated image. + + Components: + vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`list`): + Generated images. (tensor output of the vae decoder.) + """ + + model_name = "qwenimage-edit-plus" + block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()] + block_names = ["decode", "postprocess"] + + @property + def description(self): + return "Decode step that decodes the latents to images and postprocesses the generated image." + + +# ==================== +# 5. AUTO BLOCKS & PRESETS +# ==================== + +EDIT_PLUS_AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", QwenImageEditPlusVLEncoderStep()), + ("vae_encoder", QwenImageEditPlusVaeEncoderStep()), + ("denoise", QwenImageEditPlusCoreDenoiseStep()), + ("decode", QwenImageEditPlusDecodeStep()), + ] +) + + +# auto_docstring +class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for edit (img2img) tasks using QwenImage-Edit Plus. + - `image` is required input (can be single image or list of images). + - Each image is resized independently based on its own aspect ratio. + - VL encoder uses 384x384 target area, VAE encoder uses 1024x1024 target area. + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) image_processor (`VaeImageProcessor`) vae + (`AutoencoderKLQwenImage`) pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) + transformer (`QwenImageTransformer2DModel`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "qwenimage-edit-plus" + block_classes = EDIT_PLUS_AUTO_BLOCKS.values() + block_names = EDIT_PLUS_AUTO_BLOCKS.keys() + + @property + def description(self): + return ( + "Auto Modular pipeline for edit (img2img) tasks using QwenImage-Edit Plus.\n" + "- `image` is required input (can be single image or list of images).\n" + "- Each image is resized independently based on its own aspect ratio.\n" + "- VL encoder uses 384x384 target area, VAE encoder uses 1024x1024 target area." + ) + + @property + def outputs(self): + return [OutputParam.template("images")] diff --git a/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py b/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py new file mode 100644 index 0000000000000000000000000000000000000000..a10454f1fb0c9ec20a2624ce55c933efdf069214 --- /dev/null +++ b/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py @@ -0,0 +1,366 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict, OutputParam +from .before_denoise import ( + QwenImageLayeredPrepareLatentsStep, + QwenImageLayeredRoPEInputsStep, + QwenImageLayeredSetTimestepsStep, +) +from .decoders import ( + QwenImageLayeredAfterDenoiseStep, + QwenImageLayeredDecoderStep, +) +from .denoise import ( + QwenImageLayeredDenoiseStep, +) +from .encoders import ( + QwenImageEditProcessImagesInputStep, + QwenImageLayeredGetImagePromptStep, + QwenImageLayeredPermuteLatentsStep, + QwenImageLayeredResizeStep, + QwenImageTextEncoderStep, + QwenImageVaeEncoderStep, +) +from .inputs import ( + QwenImageLayeredAdditionalInputsStep, + QwenImageTextInputsStep, +) + + +logger = logging.get_logger(__name__) + + +# ==================== +# 1. TEXT ENCODER +# ==================== + + +# auto_docstring +class QwenImageLayeredTextEncoderStep(SequentialPipelineBlocks): + """ + QwenImage-Layered Text encoder step that encode the text prompt, will generate a prompt based on image if not + provided. + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) tokenizer (`Qwen2Tokenizer`): The tokenizer to use guider (`ClassifierFreeGuidance`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + resolution (`int`, *optional*, defaults to 640): + The target area to resize the image to, can be 1024 or 640 + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + use_en_prompt (`bool`, *optional*, defaults to False): + Whether to use English prompt template + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + + Outputs: + resized_image (`list`): + The resized images + prompt (`str`): + The prompt or prompts to guide image generation. If not provided, updated using image caption + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + + model_name = "qwenimage-layered" + block_classes = [ + QwenImageLayeredResizeStep(), + QwenImageLayeredGetImagePromptStep(), + QwenImageTextEncoderStep(), + ] + block_names = ["resize", "get_image_prompt", "encode"] + + @property + def description(self) -> str: + return "QwenImage-Layered Text encoder step that encode the text prompt, will generate a prompt based on image if not provided." + + +# ==================== +# 2. VAE ENCODER +# ==================== + + +# Edit VAE encoder +# auto_docstring +class QwenImageLayeredVaeEncoderStep(SequentialPipelineBlocks): + """ + Vae encoder step that encode the image inputs into their latent representations. + + Components: + image_resize_processor (`VaeImageProcessor`) image_processor (`VaeImageProcessor`) vae + (`AutoencoderKLQwenImage`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + resolution (`int`, *optional*, defaults to 640): + The target area to resize the image to, can be 1024 or 640 + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + resized_image (`list`): + The resized images + processed_image (`Tensor`): + The processed image + image_latents (`Tensor`): + The latent representation of the input image. + """ + + model_name = "qwenimage-layered" + block_classes = [ + QwenImageLayeredResizeStep(), + QwenImageEditProcessImagesInputStep(), + QwenImageVaeEncoderStep(), + QwenImageLayeredPermuteLatentsStep(), + ] + block_names = ["resize", "preprocess", "encode", "permute"] + + @property + def description(self) -> str: + return "Vae encoder step that encode the image inputs into their latent representations." + + +# ==================== +# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise) +# ==================== + + +# assemble input steps +# auto_docstring +class QwenImageLayeredInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the layered denoising step. It: + - make sure the text embeddings have consistent batch size as well as the additional inputs. + - update height/width based `image_latents`, patchify `image_latents`. + + Components: + pachifier (`QwenImageLayeredPachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified + with layered pachifier and batch-expanded) + """ + + model_name = "qwenimage-layered" + block_classes = [ + QwenImageTextInputsStep(), + QwenImageLayeredAdditionalInputsStep(), + ] + block_names = ["text_inputs", "additional_inputs"] + + @property + def description(self): + return ( + "Input step that prepares the inputs for the layered denoising step. It:\n" + " - make sure the text embeddings have consistent batch size as well as the additional inputs.\n" + " - update height/width based `image_latents`, patchify `image_latents`." + ) + + +# Qwen Image Layered (image2image) core denoise step +# auto_docstring +class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoising workflow for QwenImage-Layered img2img task. + + Components: + pachifier (`QwenImageLayeredPachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage-layered" + block_classes = [ + QwenImageLayeredInputStep(), + QwenImageLayeredPrepareLatentsStep(), + QwenImageLayeredSetTimestepsStep(), + QwenImageLayeredRoPEInputsStep(), + QwenImageLayeredDenoiseStep(), + QwenImageLayeredAfterDenoiseStep(), + ] + block_names = [ + "input", + "prepare_latents", + "set_timesteps", + "prepare_rope_inputs", + "denoise", + "after_denoise", + ] + + @property + def description(self): + return "Core denoising workflow for QwenImage-Layered img2img task." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# ==================== +# 4. AUTO BLOCKS & PRESETS +# ==================== + +LAYERED_AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", QwenImageLayeredTextEncoderStep()), + ("vae_encoder", QwenImageLayeredVaeEncoderStep()), + ("denoise", QwenImageLayeredCoreDenoiseStep()), + ("decode", QwenImageLayeredDecoderStep()), + ] +) + + +# auto_docstring +class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for layered denoising tasks using QwenImage-Layered. + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) tokenizer (`Qwen2Tokenizer`): The tokenizer to use guider (`ClassifierFreeGuidance`) + image_processor (`VaeImageProcessor`) vae (`AutoencoderKLQwenImage`) pachifier (`QwenImageLayeredPachifier`) + scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + resolution (`int`, *optional*, defaults to 640): + The target area to resize the image to, can be 1024 or 640 + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + use_en_prompt (`bool`, *optional*, defaults to False): + Whether to use English prompt template + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "qwenimage-layered" + block_classes = LAYERED_AUTO_BLOCKS.values() + block_names = LAYERED_AUTO_BLOCKS.keys() + + @property + def description(self): + return "Auto Modular pipeline for layered denoising tasks using QwenImage-Layered." + + @property + def outputs(self): + return [OutputParam.template("images")] diff --git a/diffusers/modular_pipelines/qwenimage/modular_pipeline.py b/diffusers/modular_pipelines/qwenimage/modular_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..892435989d00fc07526f3d2c5eec6efc4380a2fa --- /dev/null +++ b/diffusers/modular_pipelines/qwenimage/modular_pipeline.py @@ -0,0 +1,297 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import QwenImageLoraLoaderMixin +from ..modular_pipeline import ModularPipeline + + +class QwenImagePachifier(ConfigMixin): + """ + A class to pack and unpack latents for QwenImage. + """ + + config_name = "config.json" + + @register_to_config + def __init__(self, patch_size: int = 2): + super().__init__() + + def pack_latents(self, latents): + if latents.ndim != 4 and latents.ndim != 5: + raise ValueError(f"Latents must have 4 or 5 dimensions, but got {latents.ndim}") + + if latents.ndim == 4: + latents = latents.unsqueeze(2) + + batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width = latents.shape + patch_size = self.config.patch_size + + if latent_height % patch_size != 0 or latent_width % patch_size != 0: + raise ValueError( + f"Latent height and width must be divisible by {patch_size}, but got {latent_height} and {latent_width}" + ) + + latents = latents.view( + batch_size, + num_channels_latents, + latent_height // patch_size, + patch_size, + latent_width // patch_size, + patch_size, + ) + latents = latents.permute( + 0, 2, 4, 1, 3, 5 + ) # Batch_size, num_patches_height, num_patches_width, num_channels_latents, patch_size, patch_size + latents = latents.reshape( + batch_size, + (latent_height // patch_size) * (latent_width // patch_size), + num_channels_latents * patch_size * patch_size, + ) + + return latents + + def unpack_latents(self, latents, height, width, vae_scale_factor=8): + if latents.ndim != 3: + raise ValueError(f"Latents must have 3 dimensions, but got {latents.ndim}") + + batch_size, num_patches, channels = latents.shape + patch_size = self.config.patch_size + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = patch_size * (int(height) // (vae_scale_factor * patch_size)) + width = patch_size * (int(width) // (vae_scale_factor * patch_size)) + + latents = latents.view( + batch_size, + height // patch_size, + width // patch_size, + channels // (patch_size * patch_size), + patch_size, + patch_size, + ) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (patch_size * patch_size), 1, height, width) + + return latents + + +class QwenImageLayeredPachifier(ConfigMixin): + """ + A class to pack and unpack latents for QwenImage Layered. + + Unlike QwenImagePachifier, this handles 5D latents with shape (B, layers+1, C, H, W). + """ + + config_name = "config.json" + + @register_to_config + def __init__(self, patch_size: int = 2): + super().__init__() + + def pack_latents(self, latents): + """ + Pack latents from (B, layers, C, H, W) to (B, layers * H/2 * W/2, C*4). + """ + + if latents.ndim != 5: + raise ValueError(f"Latents must have 5 dimensions (B, layers, C, H, W), but got {latents.ndim}") + + batch_size, layers, num_channels_latents, latent_height, latent_width = latents.shape + patch_size = self.config.patch_size + + if latent_height % patch_size != 0 or latent_width % patch_size != 0: + raise ValueError( + f"Latent height and width must be divisible by {patch_size}, but got {latent_height} and {latent_width}" + ) + + latents = latents.view( + batch_size, + layers, + num_channels_latents, + latent_height // patch_size, + patch_size, + latent_width // patch_size, + patch_size, + ) + latents = latents.permute(0, 1, 3, 5, 2, 4, 6) + latents = latents.reshape( + batch_size, + layers * (latent_height // patch_size) * (latent_width // patch_size), + num_channels_latents * patch_size * patch_size, + ) + return latents + + def unpack_latents(self, latents, height, width, layers, vae_scale_factor=8): + """ + Unpack latents from (B, seq, C*4) to (B, C, layers+1, H, W). + """ + + if latents.ndim != 3: + raise ValueError(f"Latents must have 3 dimensions, but got {latents.ndim}") + + batch_size, _, channels = latents.shape + patch_size = self.config.patch_size + + height = patch_size * (int(height) // (vae_scale_factor * patch_size)) + width = patch_size * (int(width) // (vae_scale_factor * patch_size)) + + latents = latents.view( + batch_size, + layers + 1, + height // patch_size, + width // patch_size, + channels // (patch_size * patch_size), + patch_size, + patch_size, + ) + latents = latents.permute(0, 1, 4, 2, 5, 3, 6) + latents = latents.reshape( + batch_size, + layers + 1, + channels // (patch_size * patch_size), + height, + width, + ) + latents = latents.permute(0, 2, 1, 3, 4) # (b, c, f, h, w) + + return latents + + +class QwenImageModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin): + """ + A ModularPipeline for QwenImage. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "QwenImageAutoBlocks" + + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + return 128 + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** len(self.vae.temperal_downsample) + return vae_scale_factor + + @property + def num_channels_latents(self): + num_channels_latents = 16 + if hasattr(self, "transformer") and self.transformer is not None: + num_channels_latents = self.transformer.config.in_channels // 4 + return num_channels_latents + + @property + def is_guidance_distilled(self): + is_guidance_distilled = False + if hasattr(self, "transformer") and self.transformer is not None: + is_guidance_distilled = self.transformer.config.guidance_embeds + return is_guidance_distilled + + @property + def requires_unconditional_embeds(self): + requires_unconditional_embeds = False + + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 + + return requires_unconditional_embeds + + +class QwenImageEditModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin): + """ + A ModularPipeline for QwenImage-Edit. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "QwenImageEditAutoBlocks" + + # YiYi TODO: qwen edit should not provide default height/width, should be derived from the resized input image (after adjustment) produced by the resize step. + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + return 128 + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** len(self.vae.temperal_downsample) + return vae_scale_factor + + @property + def num_channels_latents(self): + num_channels_latents = 16 + if hasattr(self, "transformer") and self.transformer is not None: + num_channels_latents = self.transformer.config.in_channels // 4 + return num_channels_latents + + @property + def is_guidance_distilled(self): + is_guidance_distilled = False + if hasattr(self, "transformer") and self.transformer is not None: + is_guidance_distilled = self.transformer.config.guidance_embeds + return is_guidance_distilled + + @property + def requires_unconditional_embeds(self): + requires_unconditional_embeds = False + + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 + + return requires_unconditional_embeds + + +class QwenImageEditPlusModularPipeline(QwenImageEditModularPipeline): + """ + A ModularPipeline for QwenImage-Edit Plus. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "QwenImageEditPlusAutoBlocks" + + +class QwenImageLayeredModularPipeline(QwenImageModularPipeline): + """ + A ModularPipeline for QwenImage-Layered. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "QwenImageLayeredAutoBlocks" diff --git a/diffusers/modular_pipelines/qwenimage/prompt_templates.py b/diffusers/modular_pipelines/qwenimage/prompt_templates.py new file mode 100644 index 0000000000000000000000000000000000000000..8e7beb555760a5f987342e406f05e0697a9c7bc3 --- /dev/null +++ b/diffusers/modular_pipelines/qwenimage/prompt_templates.py @@ -0,0 +1,121 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Prompt templates for QwenImage pipelines. + +This module centralizes all prompt templates used across different QwenImage pipeline variants: +- QwenImage (base): Text-only encoding for text-to-image generation +- QwenImage Edit: VL encoding with single image for image editing +- QwenImage Edit Plus: VL encoding with multiple images for multi-reference editing +- QwenImage Layered: Auto-captioning for image decomposition +""" + +# ============================================ +# QwenImage Base (text-only encoding) +# ============================================ +# Used for text-to-image generation where only text prompt is encoded + +QWENIMAGE_PROMPT_TEMPLATE = ( + "<|im_start|>system\n" + "Describe the image by detailing the color, shape, size, texture, quantity, text, " + "spatial relationships of the objects and background:<|im_end|>\n" + "<|im_start|>user\n{}<|im_end|>\n" + "<|im_start|>assistant\n" +) +QWENIMAGE_PROMPT_TEMPLATE_START_IDX = 34 + + +# ============================================ +# QwenImage Edit (VL encoding with single image) +# ============================================ +# Used for single-image editing where both image and text are encoded together + +QWENIMAGE_EDIT_PROMPT_TEMPLATE = ( + "<|im_start|>system\n" + "Describe the key features of the input image (color, shape, size, texture, objects, background), " + "then explain how the user's text instruction should alter or modify the image. " + "Generate a new image that meets the user's requirements while maintaining consistency " + "with the original input where appropriate.<|im_end|>\n" + "<|im_start|>user\n" + "<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n" + "<|im_start|>assistant\n" +) +QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX = 64 + + +# ============================================ +# QwenImage Edit Plus (VL encoding with multiple images) +# ============================================ +# Used for multi-reference editing where multiple images and text are encoded together +# The img_template is used to format each image in the prompt + +QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE = ( + "<|im_start|>system\n" + "Describe the key features of the input image (color, shape, size, texture, objects, background), " + "then explain how the user's text instruction should alter or modify the image. " + "Generate a new image that meets the user's requirements while maintaining consistency " + "with the original input where appropriate.<|im_end|>\n" + "<|im_start|>user\n{}<|im_end|>\n" + "<|im_start|>assistant\n" +) +QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>" +QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX = 64 + + +# ============================================ +# QwenImage Layered (auto-captioning) +# ============================================ +# Used for image decomposition where the VL model generates a caption from the input image +# if no prompt is provided. These prompts instruct the model to describe the image in detail. + +QWENIMAGE_LAYERED_CAPTION_PROMPT_EN = ( + "<|im_start|>system\n" + "You are a helpful assistant.<|im_end|>\n" + "<|im_start|>user\n" + "# Image Annotator\n" + "You are a professional image annotator. Please write an image caption based on the input image:\n" + "1. Write the caption using natural, descriptive language without structured formats or rich text.\n" + "2. Enrich caption details by including:\n" + " - Object attributes, such as quantity, color, shape, size, material, state, position, actions, and so on\n" + " - Vision Relations between objects, such as spatial relations, functional relations, possessive relations, " + "attachment relations, action relations, comparative relations, causal relations, and so on\n" + " - Environmental details, such as weather, lighting, colors, textures, atmosphere, and so on\n" + " - Identify the text clearly visible in the image, without translation or explanation, " + "and highlight it in the caption with quotation marks\n" + "3. Maintain authenticity and accuracy:\n" + " - Avoid generalizations\n" + " - Describe all visible information in the image, while do not add information not explicitly shown in the image\n" + "<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n" + "<|im_start|>assistant\n" +) + +QWENIMAGE_LAYERED_CAPTION_PROMPT_CN = ( + "<|im_start|>system\n" + "You are a helpful assistant.<|im_end|>\n" + "<|im_start|>user\n" + "# 图像标注器\n" + "你是一个专业的图像标注器。请基于输入图像,撰写图注:\n" + "1. 使用自然、描述性的语言撰写图注,不要使用结构化形式或富文本形式。\n" + "2. 通过加入以下内容,丰富图注细节:\n" + " - 对象的属性:如数量、颜色、形状、大小、位置、材质、状态、动作等\n" + " - 对象间的视觉关系:如空间关系、功能关系、动作关系、从属关系、比较关系、因果关系等\n" + " - 环境细节:例如天气、光照、颜色、纹理、气氛等\n" + " - 文字内容:识别图像中清晰可见的文字,不做翻译和解释,用引号在图注中强调\n" + "3. 保持真实性与准确性:\n" + " - 不要使用笼统的描述\n" + " - 描述图像中所有可见的信息,但不要加入没有在图像中出现的内容\n" + "<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n" + "<|im_start|>assistant\n" +) diff --git a/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py b/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..44f1c555cef3586ac9c05d53d45773040e3e77b1 --- /dev/null +++ b/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modular_blocks_stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks"] + _import_structure["modular_pipeline"] = ["StableDiffusionXLModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_blocks_stable_diffusion_xl import StableDiffusionXLAutoBlocks + from .modular_pipeline import StableDiffusionXLModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..4a393e7ce296f13d0cb470535ff9a354dffcbdd6 --- /dev/null +++ b/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -0,0 +1,1874 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any + +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, UNet2DConditionModel +from ...models.controlnets.multicontrolnet import MultiControlNetModel +from ...schedulers import EulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor, unwrap_module +from ..modular_pipeline import ( + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from .modular_pipeline import StableDiffusionXLModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that +# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by +# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the +# configuration of guider is. + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def prepare_latents_img2img( + vae, scheduler, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True +): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError(f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}") + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + latents_mean = latents_std = None + if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None: + latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None: + latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1) + # make sure the VAE is in float32 mode, as it overflows in float16 + if vae.config.force_upcast: + image = image.float() + vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(vae.encode(image), generator=generator) + + if vae.config.force_upcast: + vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * vae.config.scaling_factor / latents_std + else: + init_latents = vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + +class StableDiffusionXLInputStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Input processing step that:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n" + "All input tensors are expected to have either batch_size=1 or match the batch_size\n" + "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" + "have a final batch_size of batch_size * num_images_per_prompt." + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + description="Pre-generated negative text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "negative_pooled_prompt_embeds", + description="Pre-generated negative pooled text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "ip_adapter_embeds", + type_hint=list[torch.Tensor], + description="Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step.", + ), + InputParam( + "negative_ip_adapter_embeds", + type_hint=list[torch.Tensor], + description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `prompt_embeds`)", + ), + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields + description="text embeddings used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields + description="negative text embeddings used to guide the image generation", + ), + OutputParam( + "pooled_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields + description="pooled text embeddings used to guide the image generation", + ), + OutputParam( + "negative_pooled_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields + description="negative pooled text embeddings used to guide the image generation", + ), + OutputParam( + "ip_adapter_embeds", + type_hint=list[torch.Tensor], + kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields + description="image embeddings for IP-Adapter", + ), + OutputParam( + "negative_ip_adapter_embeds", + type_hint=list[torch.Tensor], + kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields + description="negative image embeddings for IP-Adapter", + ), + ] + + def check_inputs(self, components, block_state): + if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None: + if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`" + f" {block_state.negative_prompt_embeds.shape}." + ) + + if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if block_state.negative_prompt_embeds is not None and block_state.negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if block_state.ip_adapter_embeds is not None and not isinstance(block_state.ip_adapter_embeds, list): + raise ValueError("`ip_adapter_embeds` must be a list") + + if block_state.negative_ip_adapter_embeds is not None and not isinstance( + block_state.negative_ip_adapter_embeds, list + ): + raise ValueError("`negative_ip_adapter_embeds` must be a list") + + if block_state.ip_adapter_embeds is not None and block_state.negative_ip_adapter_embeds is not None: + for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): + if ip_adapter_embed.shape != block_state.negative_ip_adapter_embeds[i].shape: + raise ValueError( + "`ip_adapter_embeds` and `negative_ip_adapter_embeds` must have the same shape when passed directly, but" + f" got: `ip_adapter_embeds` {ip_adapter_embed.shape} != `negative_ip_adapter_embeds`" + f" {block_state.negative_ip_adapter_embeds[i].shape}." + ) + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, -1 + ) + + if block_state.negative_pooled_prompt_embeds is not None: + block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, -1 + ) + + if block_state.ip_adapter_embeds is not None: + for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): + block_state.ip_adapter_embeds[i] = torch.cat( + [ip_adapter_embed] * block_state.num_images_per_prompt, dim=0 + ) + + if block_state.negative_ip_adapter_embeds is not None: + for i, negative_ip_adapter_embed in enumerate(block_state.negative_ip_adapter_embeds): + block_state.negative_ip_adapter_embeds[i] = torch.cat( + [negative_ip_adapter_embed] * block_state.num_images_per_prompt, dim=0 + ) + + self.set_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLImg2ImgSetTimestepsStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Step that sets the timesteps for the scheduler and determines the initial noise level (latent_timestep) for image-to-image/inpainting generation.\n" + + "The latent_timestep is calculated from the `strength` parameter - higher strength means starting from a noisier version of the input image." + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("denoising_end"), + InputParam("strength", default=0.3), + InputParam("denoising_start"), + # YiYi TODO: do we need num_images_per_prompt here? + InputParam("num_images_per_prompt", default=1), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam( + "num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time", + ), + OutputParam( + "latent_timestep", + type_hint=torch.Tensor, + description="The timestep that represents the initial noise level for image-to-image generation", + ), + ] + + @staticmethod + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps with self->components + def get_timesteps(components, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = components.scheduler.timesteps[t_start * components.scheduler.order :] + if hasattr(components.scheduler, "set_begin_index"): + components.scheduler.set_begin_index(t_start * components.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + components.scheduler.config.num_train_timesteps + - (denoising_start * components.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (components.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if components.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(components.scheduler.timesteps) - num_inference_steps + timesteps = components.scheduler.timesteps[t_start:] + if hasattr(components.scheduler, "set_begin_index"): + components.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.device = components._execution_device + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, + block_state.num_inference_steps, + block_state.device, + block_state.timesteps, + block_state.sigmas, + ) + + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + block_state.timesteps, block_state.num_inference_steps = self.get_timesteps( + components, + block_state.num_inference_steps, + block_state.strength, + block_state.device, + denoising_start=block_state.denoising_start + if denoising_value_valid(block_state.denoising_start) + else None, + ) + block_state.latent_timestep = block_state.timesteps[:1].repeat( + block_state.batch_size * block_state.num_images_per_prompt + ) + + if ( + block_state.denoising_end is not None + and isinstance(block_state.denoising_end, float) + and block_state.denoising_end > 0 + and block_state.denoising_end < 1 + ): + block_state.discrete_timestep_cutoff = int( + round( + components.scheduler.config.num_train_timesteps + - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) + ) + ) + block_state.num_inference_steps = len( + list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps)) + ) + block_state.timesteps = block_state.timesteps[: block_state.num_inference_steps] + + self.set_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLSetTimestepsStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("denoising_end"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam( + "num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time", + ), + ] + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.device = components._execution_device + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, + block_state.num_inference_steps, + block_state.device, + block_state.timesteps, + block_state.sigmas, + ) + + if ( + block_state.denoising_end is not None + and isinstance(block_state.denoising_end, float) + and block_state.denoising_end > 0 + and block_state.denoising_end < 1 + ): + block_state.discrete_timestep_cutoff = int( + round( + components.scheduler.config.num_train_timesteps + - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) + ) + ) + block_state.num_inference_steps = len( + list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps)) + ) + block_state.timesteps = block_state.timesteps[: block_state.num_inference_steps] + + self.set_block_state(state, block_state) + return components, state + + +class StableDiffusionXLInpaintPrepareLatentsStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "Step that prepares the latents for the inpainting process" + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + InputParam("denoising_start"), + InputParam( + "strength", + default=0.9999, + description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` " + "will be used as a starting point, adding more noise to it the larger the `strength`. The number of " + "denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will " + "be maximum and the denoising process will run for the full number of iterations specified in " + "`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of " + "`denoising_start` being declared as an integer, the value of `strength` will be ignored.", + ), + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", + ), + InputParam( + "latent_timestep", + required=True, + type_hint=torch.Tensor, + description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.", + ), + InputParam( + "image_latents", + required=True, + type_hint=torch.Tensor, + description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.", + ), + InputParam( + "mask", + required=True, + type_hint=torch.Tensor, + description="The mask for the inpainting generation. Can be generated in vae_encode step.", + ), + InputParam( + "masked_image_latents", + type_hint=torch.Tensor, + description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step.", + ), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ), + OutputParam( + "noise", + type_hint=torch.Tensor, + description="The noise added to the image latents, used for inpainting generation", + ), + ] + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self->components + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + @staticmethod + def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generator): + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std + else: + image_latents = components.vae.config.scaling_factor * image_latents + + return image_latents + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents adding components as first argument + def prepare_latents_inpaint( + self, + components, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + add_noise=True, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // components.vae_scale_factor, + int(width) // components.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if image.shape[1] == 4: + image_latents = image.to(device=device, dtype=dtype) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + elif latents is None and not is_strength_max: + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(components, image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None and add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else components.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * components.scheduler.init_noise_sigma if is_strength_max else latents + elif add_noise: + noise = latents.to(device) + latents = noise * components.scheduler.init_noise_sigma + else: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = image_latents.to(device) + + outputs = (latents, noise, image_latents) + + return outputs + + # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents + # do not accept do_classifier_free_guidance + def prepare_mask_latents( + self, components, mask, masked_image, batch_size, height, width, dtype, device, generator + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + if masked_image is not None and masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if masked_image is not None: + if masked_image_latents is None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) + + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + + block_state.is_strength_max = block_state.strength == 1.0 + + # for non-inpainting specific unet, we do not need masked_image_latents + if hasattr(components, "unet") and components.unet is not None: + if components.unet.config.in_channels == 4: + block_state.masked_image_latents = None + + block_state.add_noise = True if block_state.denoising_start is None else False + + block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor + block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor + + block_state.latents, block_state.noise, block_state.image_latents = self.prepare_latents_inpaint( + components, + block_state.batch_size * block_state.num_images_per_prompt, + components.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + image=block_state.image_latents, + timestep=block_state.latent_timestep, + is_strength_max=block_state.is_strength_max, + add_noise=block_state.add_noise, + ) + + # 7. Prepare mask latent variables + block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( + components, + block_state.mask, + block_state.masked_image_latents, + block_state.batch_size * block_state.num_images_per_prompt, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + ) + + self.set_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLImg2ImgPrepareLatentsStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "Step that prepares the latents for the image-to-image generation process" + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + InputParam("denoising_start"), + InputParam("generator"), + InputParam( + "latent_timestep", + required=True, + type_hint=torch.Tensor, + description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.", + ), + InputParam( + "image_latents", + required=True, + type_hint=torch.Tensor, + description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.", + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", + ), + InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ) + ] + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + block_state.add_noise = True if block_state.denoising_start is None else False + if block_state.latents is None: + block_state.latents = prepare_latents_img2img( + components.vae, + components.scheduler, + block_state.image_latents, + block_state.latent_timestep, + block_state.batch_size, + block_state.num_images_per_prompt, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.add_noise, + ) + + self.set_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLPrepareLatentsStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("vae", AutoencoderKL), + ] + + @property + def description(self) -> str: + return "Prepare latents step that prepares the latents for the text-to-image generation process" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("height"), + InputParam("width"), + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", + ), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ) + ] + + @staticmethod + def check_inputs(components, block_state): + if ( + block_state.height is not None + and block_state.height % components.vae_scale_factor != 0 + or block_state.width is not None + and block_state.width % components.vae_scale_factor != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}." + ) + + @staticmethod + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with self->comp + def prepare_latents(comp, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // comp.vae_scale_factor, + int(width) // comp.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * comp.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if block_state.dtype is None: + block_state.dtype = components.vae.dtype + + block_state.device = components._execution_device + + self.check_inputs(components, block_state) + + block_state.height = block_state.height or components.default_sample_size * components.vae_scale_factor + block_state.width = block_state.width or components.default_sample_size * components.vae_scale_factor + block_state.num_channels_latents = components.num_channels_latents + block_state.latents = self.prepare_latents( + components, + block_state.batch_size * block_state.num_images_per_prompt, + block_state.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + ) + + self.set_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_configs(self) -> list[ConfigSpec]: + return [ + ConfigSpec("requires_aesthetics_score", False), + ] + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that prepares the additional conditioning for the image-to-image/inpainting generation process" + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("original_size"), + InputParam("target_size"), + InputParam("negative_original_size"), + InputParam("negative_target_size"), + InputParam("crops_coords_top_left", default=(0, 0)), + InputParam("negative_crops_coords_top_left", default=(0, 0)), + InputParam("num_images_per_prompt", default=1), + InputParam("aesthetic_score", default=6.0), + InputParam("negative_aesthetic_score", default=2.0), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step.", + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "add_time_ids", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="The time ids to condition the denoising process", + ), + OutputParam( + "negative_add_time_ids", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="The negative time ids to condition the denoising process", + ), + OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"), + ] + + @staticmethod + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self->components + def _get_add_time_ids( + components, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if components.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == components.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == components.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + block_state.vae_scale_factor = components.vae_scale_factor + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * block_state.vae_scale_factor + block_state.width = block_state.width * block_state.vae_scale_factor + + block_state.original_size = block_state.original_size or (block_state.height, block_state.width) + block_state.target_size = block_state.target_size or (block_state.height, block_state.width) + + block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) + + if block_state.negative_original_size is None: + block_state.negative_original_size = block_state.original_size + if block_state.negative_target_size is None: + block_state.negative_target_size = block_state.target_size + + block_state.add_time_ids, block_state.negative_add_time_ids = self._get_add_time_ids( + components, + block_state.original_size, + block_state.crops_coords_top_left, + block_state.target_size, + block_state.aesthetic_score, + block_state.negative_aesthetic_score, + block_state.negative_original_size, + block_state.negative_crops_coords_top_left, + block_state.negative_target_size, + dtype=block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, + ) + block_state.add_time_ids = block_state.add_time_ids.repeat( + block_state.batch_size * block_state.num_images_per_prompt, 1 + ).to(device=block_state.device) + block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat( + block_state.batch_size * block_state.num_images_per_prompt, 1 + ).to(device=block_state.device) + + # Optionally get Guidance Scale Embedding for LCM + block_state.timestep_cond = None + if ( + hasattr(components, "unet") + and components.unet is not None + and components.unet.config.time_cond_proj_dim is not None + ): + # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! + block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat( + block_state.batch_size * block_state.num_images_per_prompt + ) + block_state.timestep_cond = self.get_guidance_scale_embedding( + block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim + ).to(device=block_state.device, dtype=block_state.latents.dtype) + + self.set_block_state(state, block_state) + return components, state + + +class StableDiffusionXLPrepareAdditionalConditioningStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return "Step that prepares the additional conditioning for the text-to-image generation process" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("original_size"), + InputParam("target_size"), + InputParam("negative_original_size"), + InputParam("negative_target_size"), + InputParam("crops_coords_top_left", default=(0, 0)), + InputParam("negative_crops_coords_top_left", default=(0, 0)), + InputParam("num_images_per_prompt", default=1), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step.", + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "add_time_ids", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="The time ids to condition the denoising process", + ), + OutputParam( + "negative_add_time_ids", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="The negative time ids to condition the denoising process", + ), + OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"), + ] + + @staticmethod + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self->components + def _get_add_time_ids( + components, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + block_state.original_size = block_state.original_size or (block_state.height, block_state.width) + block_state.target_size = block_state.target_size or (block_state.height, block_state.width) + + block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) + + block_state.add_time_ids = self._get_add_time_ids( + components, + block_state.original_size, + block_state.crops_coords_top_left, + block_state.target_size, + block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, + ) + if block_state.negative_original_size is not None and block_state.negative_target_size is not None: + block_state.negative_add_time_ids = self._get_add_time_ids( + components, + block_state.negative_original_size, + block_state.negative_crops_coords_top_left, + block_state.negative_target_size, + block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, + ) + else: + block_state.negative_add_time_ids = block_state.add_time_ids + + block_state.add_time_ids = block_state.add_time_ids.repeat( + block_state.batch_size * block_state.num_images_per_prompt, 1 + ).to(device=block_state.device) + block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat( + block_state.batch_size * block_state.num_images_per_prompt, 1 + ).to(device=block_state.device) + + # Optionally get Guidance Scale Embedding for LCM + block_state.timestep_cond = None + if ( + hasattr(components, "unet") + and components.unet is not None + and components.unet.config.time_cond_proj_dim is not None + ): + # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! + block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat( + block_state.batch_size * block_state.num_images_per_prompt + ) + block_state.timestep_cond = self.get_guidance_scale_embedding( + block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim + ).to(device=block_state.device, dtype=block_state.latents.dtype) + + self.set_block_state(state, block_state) + return components, state + + +class StableDiffusionXLControlNetInputStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("controlnet", ControlNetModel), + ComponentSpec( + "control_image_processor", + VaeImageProcessor, + config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "step that prepare inputs for controlnet" + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("control_image", required=True), + InputParam("control_guidance_start", default=0.0), + InputParam("control_guidance_end", default=1.0), + InputParam("controlnet_conditioning_scale", default=1.0), + InputParam("guess_mode", default=False), + InputParam("num_images_per_prompt", default=1), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + "crops_coords", + type_hint=tuple[int] | None, + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("controlnet_cond", type_hint=torch.Tensor, description="The processed control image"), + OutputParam( + "control_guidance_start", type_hint=list[float], description="The controlnet guidance start values" + ), + OutputParam( + "control_guidance_end", type_hint=list[float], description="The controlnet guidance end values" + ), + OutputParam( + "conditioning_scale", type_hint=list[float], description="The controlnet conditioning scale values" + ), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), + OutputParam("controlnet_keep", type_hint=list[float], description="The controlnet keep values"), + ] + + # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + # 1. return image without apply any guidance + # 2. add crops_coords and resize_mode to preprocess() + @staticmethod + def prepare_control_image( + components, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords=None, + ): + if crops_coords is not None: + image = components.control_image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill" + ).to(dtype=torch.float32) + else: + image = components.control_image_processor.preprocess(image, height=height, width=width).to( + dtype=torch.float32 + ) + + image_batch_size = image.shape[0] + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + image = image.to(device=device, dtype=dtype) + return image + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # (1) prepare controlnet inputs + block_state.device = components._execution_device + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + controlnet = unwrap_module(components.controlnet) + + # (1.1) + # control_guidance_start/control_guidance_end (align format) + if not isinstance(block_state.control_guidance_start, list) and isinstance( + block_state.control_guidance_end, list + ): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [ + block_state.control_guidance_start + ] + elif not isinstance(block_state.control_guidance_end, list) and isinstance( + block_state.control_guidance_start, list + ): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [ + block_state.control_guidance_end + ] + elif not isinstance(block_state.control_guidance_start, list) and not isinstance( + block_state.control_guidance_end, list + ): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + block_state.control_guidance_start, block_state.control_guidance_end = ( + mult * [block_state.control_guidance_start], + mult * [block_state.control_guidance_end], + ) + + # (1.2) + # controlnet_conditioning_scale (align format) + if isinstance(controlnet, MultiControlNetModel) and isinstance( + block_state.controlnet_conditioning_scale, float + ): + block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len( + controlnet.nets + ) + + # (1.3) + # global_pool_conditions + block_state.global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + # (1.4) + # guess_mode + block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions + + # (1.5) + # control_image + if isinstance(controlnet, ControlNetModel): + block_state.control_image = self.prepare_control_image( + components, + image=block_state.control_image, + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=block_state.device, + dtype=controlnet.dtype, + crops_coords=block_state.crops_coords, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in block_state.control_image: + control_image = self.prepare_control_image( + components, + image=control_image_, + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=block_state.device, + dtype=controlnet.dtype, + crops_coords=block_state.crops_coords, + ) + + control_images.append(control_image) + + block_state.control_image = control_images + else: + assert False + + # (1.6) + # controlnet_keep + block_state.controlnet_keep = [] + for i in range(len(block_state.timesteps)): + keeps = [ + 1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e) + for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end) + ] + block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + block_state.controlnet_cond = block_state.control_image + block_state.conditioning_scale = block_state.controlnet_conditioning_scale + + self.set_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLControlNetUnionInputStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("controlnet", ControlNetUnionModel), + ComponentSpec( + "control_image_processor", + VaeImageProcessor, + config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "step that prepares inputs for the ControlNetUnion model" + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("control_image", required=True), + InputParam("control_mode", required=True), + InputParam("control_guidance_start", default=0.0), + InputParam("control_guidance_end", default=1.0), + InputParam("controlnet_conditioning_scale", default=1.0), + InputParam("guess_mode", default=False), + InputParam("num_images_per_prompt", default=1), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Used to determine the shape of the control images. Can be generated in prepare_latent step.", + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", + ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of model tensor inputs. Can be generated in input step.", + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Needed to determine `controlnet_keep`. Can be generated in set_timesteps step.", + ), + InputParam( + "crops_coords", + type_hint=tuple[int] | None, + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("controlnet_cond", type_hint=list[torch.Tensor], description="The processed control images"), + OutputParam( + "control_type_idx", + type_hint=list[int], + description="The control mode indices", + kwargs_type="controlnet_kwargs", + ), + OutputParam( + "control_type", + type_hint=torch.Tensor, + description="The control type tensor that specifies which control type is active", + kwargs_type="controlnet_kwargs", + ), + OutputParam("control_guidance_start", type_hint=float, description="The controlnet guidance start value"), + OutputParam("control_guidance_end", type_hint=float, description="The controlnet guidance end value"), + OutputParam( + "conditioning_scale", type_hint=list[float], description="The controlnet conditioning scale values" + ), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), + OutputParam("controlnet_keep", type_hint=list[float], description="The controlnet keep values"), + ] + + # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + # 1. return image without apply any guidance + # 2. add crops_coords and resize_mode to preprocess() + @staticmethod + def prepare_control_image( + components, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords=None, + ): + if crops_coords is not None: + image = components.control_image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill" + ).to(dtype=torch.float32) + else: + image = components.control_image_processor.preprocess(image, height=height, width=width).to( + dtype=torch.float32 + ) + + image_batch_size = image.shape[0] + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + image = image.to(device=device, dtype=dtype) + return image + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + controlnet = unwrap_module(components.controlnet) + + device = components._execution_device + dtype = block_state.dtype or components.controlnet.dtype + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + # control_guidance_start/control_guidance_end (align format) + if not isinstance(block_state.control_guidance_start, list) and isinstance( + block_state.control_guidance_end, list + ): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [ + block_state.control_guidance_start + ] + elif not isinstance(block_state.control_guidance_end, list) and isinstance( + block_state.control_guidance_start, list + ): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [ + block_state.control_guidance_end + ] + + # guess_mode + block_state.global_pool_conditions = controlnet.config.global_pool_conditions + block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions + + # control_image + if not isinstance(block_state.control_image, list): + block_state.control_image = [block_state.control_image] + # control_mode + if not isinstance(block_state.control_mode, list): + block_state.control_mode = [block_state.control_mode] + + if len(block_state.control_image) != len(block_state.control_mode): + raise ValueError("Expected len(control_image) == len(control_type)") + + # control_type + block_state.num_control_type = controlnet.config.num_control_type + block_state.control_type = [0 for _ in range(block_state.num_control_type)] + for control_idx in block_state.control_mode: + block_state.control_type[control_idx] = 1 + block_state.control_type = torch.Tensor(block_state.control_type) + + block_state.control_type = block_state.control_type.reshape(1, -1).to(device, dtype=block_state.dtype) + repeat_by = block_state.batch_size * block_state.num_images_per_prompt // block_state.control_type.shape[0] + block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0) + + # prepare control_image + for idx, _ in enumerate(block_state.control_image): + block_state.control_image[idx] = self.prepare_control_image( + components, + image=block_state.control_image[idx], + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=device, + dtype=dtype, + crops_coords=block_state.crops_coords, + ) + block_state.height, block_state.width = block_state.control_image[idx].shape[-2:] + + # controlnet_keep + block_state.controlnet_keep = [] + for i in range(len(block_state.timesteps)): + block_state.controlnet_keep.append( + 1.0 + - float( + i / len(block_state.timesteps) < block_state.control_guidance_start + or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end + ) + ) + block_state.control_type_idx = block_state.control_mode + block_state.controlnet_cond = block_state.control_image + block_state.conditioning_scale = block_state.controlnet_conditioning_scale + + self.set_block_state(state, block_state) + + return components, state diff --git a/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py b/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py new file mode 100644 index 0000000000000000000000000000000000000000..7e505559f68569f9c36c4971f76c9513dad016ef --- /dev/null +++ b/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py @@ -0,0 +1,198 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import numpy as np +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL +from ...utils import deprecate, logging +from ..modular_pipeline import ( + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableDiffusionXLDecodeStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images" + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("output_type", default="pil"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents from the denoising step", + ), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "images", + type_hint=list[PIL.Image.Image] | list[torch.Tensor] | list[np.array], + description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array", + ) + ] + + @staticmethod + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self->components + def upcast_vae(components): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + components.vae.to(dtype=torch.float32) + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if not block_state.output_type == "latent": + latents = block_state.latents + # make sure the VAE is in float32 mode, as it overflows in float16 + block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast + + if block_state.needs_upcasting: + self.upcast_vae(components) + latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != components.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + components.vae = components.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + block_state.has_latents_mean = ( + hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None + ) + block_state.has_latents_std = ( + hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None + ) + if block_state.has_latents_mean and block_state.has_latents_std: + block_state.latents_mean = ( + torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + block_state.latents_std = ( + torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = ( + latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean + ) + else: + latents = latents / components.vae.config.scaling_factor + + block_state.images = components.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if block_state.needs_upcasting: + components.vae.to(dtype=torch.float16) + else: + block_state.images = block_state.latents + + # apply watermark if available + if hasattr(components, "watermark") and components.watermark is not None: + block_state.images = components.watermark.apply_watermark(block_state.images) + + block_state.images = components.image_processor.postprocess( + block_state.images, output_type=block_state.output_type + ) + + self.set_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLInpaintOverlayMaskStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "A post-processing step that overlays the mask on the image (inpainting task only).\n" + + "only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask" + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("image"), + InputParam("mask_image"), + InputParam("padding_mask_crop"), + InputParam( + "images", + type_hint=list[PIL.Image.Image] | list[torch.Tensor] | list[np.array], + description="The generated images from the decode step", + ), + InputParam( + "crops_coords", + type_hint=tuple[int, int], + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.", + ), + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if block_state.padding_mask_crop is not None and block_state.crops_coords is not None: + block_state.images = [ + components.image_processor.apply_overlay( + block_state.mask_image, block_state.image, i, block_state.crops_coords + ) + for i in block_state.images + ] + + self.set_block_state(state, block_state) + + return components, state diff --git a/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..0190bc3ea62fa4842ffd0c944dcf61dc5ef64d9d --- /dev/null +++ b/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -0,0 +1,798 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any + +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import ControlNetModel, UNet2DConditionModel +from ...schedulers import EulerDiscreteScheduler +from ...utils import logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import StableDiffusionXLModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# YiYi experimenting composible denoise loop +# loop step (1): prepare latent input for denoiser +class StableDiffusionXLLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepare the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): + block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + + return components, block_state + + +# loop step (1): prepare latent input for denoiser (with inpainting) +class StableDiffusionXLInpaintLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepare the latent input for the denoiser (for inpainting workflow only). " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object" + ) + + @property + def inputs(self) -> list[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "mask", + type_hint=torch.Tensor | None, + description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.", + ), + InputParam( + "masked_image_latents", + type_hint=torch.Tensor | None, + description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.", + ), + ] + + @staticmethod + def check_inputs(components, block_state): + num_channels_unet = components.num_channels_unet + if num_channels_unet == 9: + # default case for stable-diffusion-v1-5/stable-diffusion-inpainting + if block_state.mask is None or block_state.masked_image_latents is None: + raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") + num_channels_latents = block_state.latents.shape[1] + num_channels_mask = block_state.mask.shape[1] + num_channels_masked_image = block_state.masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: + raise ValueError( + f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" + f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" + " `components.unet` or your `mask_image` or `image` input." + ) + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): + self.check_inputs(components, block_state) + + block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + if components.num_channels_unet == 9: + block_state.scaled_latents = torch.cat( + [block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1 + ) + + return components, block_state + + +# loop step (2): denoise the latents with guidance +class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config", + ), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents with guidance. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("cross_attention_kwargs"), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + "timestep_cond", + type_hint=torch.Tensor | None, + description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step.", + ), + InputParam( + kwargs_type="denoiser_input_fields", + description=( + "All conditional model inputs that need to be prepared with guider. " + "It should contain prompt_embeds/negative_prompt_embeds, " + "add_time_ids/negative_add_time_ids, " + "pooled_prompt_embeds/negative_pooled_prompt_embeds, " + "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." + "please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ), + ), + ] + + @torch.no_grad() + def __call__( + self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int + ) -> PipelineState: + # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) + # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) + guider_inputs = { + "prompt_embeds": ( + getattr(block_state, "prompt_embeds", None), + getattr(block_state, "negative_prompt_embeds", None), + ), + "time_ids": ( + getattr(block_state, "add_time_ids", None), + getattr(block_state, "negative_add_time_ids", None), + ), + "text_embeds": ( + getattr(block_state, "pooled_prompt_embeds", None), + getattr(block_state, "negative_pooled_prompt_embeds", None), + ), + "image_embeds": ( + getattr(block_state, "ip_adapter_embeds", None), + getattr(block_state, "negative_ip_adapter_embeds", None), + ), + } + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). + guider_state = components.guider.prepare_inputs(guider_inputs) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.unet) + cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()} + prompt_embeds = cond_kwargs.pop("prompt_embeds") + + # Predict the noise residual + # store the noise_pred in guider_state_batch so that we can apply guidance across all batches + guider_state_batch.noise_pred = components.unet( + block_state.scaled_latents, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=block_state.timestep_cond, + cross_attention_kwargs=block_state.cross_attention_kwargs, + added_cond_kwargs=cond_kwargs, + return_dict=False, + )[0] + components.guider.cleanup_models(components.unet) + + # Perform guidance + block_state.noise_pred = components.guider(guider_state)[0] + + return components, block_state + + +# loop step (2): denoise the latents with guidance (with controlnet) +class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config", + ), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec("controlnet", ControlNetModel), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that denoise the latents with guidance (with controlnet). " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("cross_attention_kwargs"), + InputParam( + "controlnet_cond", + required=True, + type_hint=torch.Tensor, + description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", + ), + InputParam( + "conditioning_scale", + type_hint=float, + description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", + ), + InputParam( + "guess_mode", + required=True, + type_hint=bool, + description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", + ), + InputParam( + "controlnet_keep", + required=True, + type_hint=list[float], + description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", + ), + InputParam( + "timestep_cond", + type_hint=torch.Tensor | None, + description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + kwargs_type="denoiser_input_fields", + description=( + "All conditional model inputs that need to be prepared with guider. " + "It should contain prompt_embeds/negative_prompt_embeds, " + "add_time_ids/negative_add_time_ids, " + "pooled_prompt_embeds/negative_pooled_prompt_embeds, " + "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." + "please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ), + ), + InputParam( + kwargs_type="controlnet_kwargs", + description=( + "additional kwargs for controlnet (e.g. control_type_idx and control_type from the controlnet union input step )" + "please add `kwargs_type=controlnet_kwargs` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ), + ), + ] + + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): + extra_controlnet_kwargs = self.prepare_extra_kwargs( + components.controlnet.forward, **block_state.controlnet_kwargs + ) + + # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) + # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) + guider_inputs = { + "prompt_embeds": ( + getattr(block_state, "prompt_embeds", None), + getattr(block_state, "negative_prompt_embeds", None), + ), + "time_ids": ( + getattr(block_state, "add_time_ids", None), + getattr(block_state, "negative_add_time_ids", None), + ), + "text_embeds": ( + getattr(block_state, "pooled_prompt_embeds", None), + getattr(block_state, "negative_pooled_prompt_embeds", None), + ), + "image_embeds": ( + getattr(block_state, "ip_adapter_embeds", None), + getattr(block_state, "negative_ip_adapter_embeds", None), + ), + } + + # cond_scale for the timestep (controlnet input) + if isinstance(block_state.controlnet_keep[i], list): + block_state.cond_scale = [ + c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i]) + ] + else: + controlnet_cond_scale = block_state.conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i] + + # default controlnet output/unet input for guess mode + conditional path + block_state.down_block_res_samples_zeros = None + block_state.mid_block_res_sample_zeros = None + + # guided denoiser step + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). + guider_state = components.guider.prepare_inputs(guider_inputs) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.unet) + + # Prepare additional conditionings + added_cond_kwargs = { + "text_embeds": guider_state_batch.text_embeds, + "time_ids": guider_state_batch.time_ids, + } + if hasattr(guider_state_batch, "image_embeds") and guider_state_batch.image_embeds is not None: + added_cond_kwargs["image_embeds"] = guider_state_batch.image_embeds + + # Prepare controlnet additional conditionings + controlnet_added_cond_kwargs = { + "text_embeds": guider_state_batch.text_embeds, + "time_ids": guider_state_batch.time_ids, + } + # run controlnet for the guidance batch + if block_state.guess_mode and not components.guider.is_conditional: + # guider always run uncond batch first, so these tensors should be set already + down_block_res_samples = block_state.down_block_res_samples_zeros + mid_block_res_sample = block_state.mid_block_res_sample_zeros + else: + down_block_res_samples, mid_block_res_sample = components.controlnet( + block_state.scaled_latents, + t, + encoder_hidden_states=guider_state_batch.prompt_embeds, + controlnet_cond=block_state.controlnet_cond, + conditioning_scale=block_state.cond_scale, + guess_mode=block_state.guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + **extra_controlnet_kwargs, + ) + + # assign it to block_state so it will be available for the uncond guidance batch + if block_state.down_block_res_samples_zeros is None: + block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in down_block_res_samples] + if block_state.mid_block_res_sample_zeros is None: + block_state.mid_block_res_sample_zeros = torch.zeros_like(mid_block_res_sample) + + # Predict the noise + # store the noise_pred in guider_state_batch so we can apply guidance across all batches + guider_state_batch.noise_pred = components.unet( + block_state.scaled_latents, + t, + encoder_hidden_states=guider_state_batch.prompt_embeds, + timestep_cond=block_state.timestep_cond, + cross_attention_kwargs=block_state.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + components.guider.cleanup_models(components.unet) + + # Perform guidance + block_state.noise_pred = components.guider(guider_state)[0] + + return components, block_state + + +# loop step (3): scheduler step to update latents +class StableDiffusionXLLoopAfterDenoiser(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that update the latents. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("eta", default=0.0), + InputParam("generator"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + # YiYi TODO: move this out of here + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + block_state.extra_step_kwargs = self.prepare_extra_kwargs( + components.scheduler.step, generator=block_state.generator, eta=block_state.eta + ) + + # Perform scheduler step using the predicted output + block_state.latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred, + t, + block_state.latents, + **block_state.extra_step_kwargs, + return_dict=False, + )[0] + + if block_state.latents.dtype != block_state.latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + block_state.latents = block_state.latents.to(block_state.latents_dtype) + + return components, block_state + + +# loop step (3): scheduler step to update latents (with inpainting) +class StableDiffusionXLInpaintLoopAfterDenoiser(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that update the latents (for inpainting workflow only). " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("eta", default=0.0), + InputParam("generator"), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + "mask", + type_hint=torch.Tensor | None, + description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.", + ), + InputParam( + "noise", + type_hint=torch.Tensor | None, + description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step.", + ), + InputParam( + "image_latents", + type_hint=torch.Tensor | None, + description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + def check_inputs(self, components, block_state): + if components.num_channels_unet == 4: + if block_state.image_latents is None: + raise ValueError(f"image_latents is required for this step {self.__class__.__name__}") + if block_state.mask is None: + raise ValueError(f"mask is required for this step {self.__class__.__name__}") + if block_state.noise is None: + raise ValueError(f"noise is required for this step {self.__class__.__name__}") + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): + self.check_inputs(components, block_state) + + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + block_state.extra_step_kwargs = self.prepare_extra_kwargs( + components.scheduler.step, generator=block_state.generator, eta=block_state.eta + ) + + # Perform scheduler step using the predicted output + block_state.latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred, + t, + block_state.latents, + **block_state.extra_step_kwargs, + return_dict=False, + )[0] + + if block_state.latents.dtype != block_state.latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + block_state.latents = block_state.latents.to(block_state.latents_dtype) + + # adjust latent for inpainting + if components.num_channels_unet == 4: + block_state.init_latents_proper = block_state.image_latents + if i < len(block_state.timesteps) - 1: + block_state.noise_timestep = block_state.timesteps[i + 1] + block_state.init_latents_proper = components.scheduler.add_noise( + block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) + ) + + block_state.latents = ( + 1 - block_state.mask + ) * block_state.init_latents_proper + block_state.mask * block_state.latents + + return components, block_state + + +# the loop wrapper that iterates over the timesteps +class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoise the latents over `timesteps`. " + "The specific steps with each iteration can be customized with `sub_blocks` attributes" + ) + + @property + def loop_expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config", + ), + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def loop_inputs(self) -> list[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False + if block_state.disable_guidance: + components.guider.disable() + else: + components.guider.enable() + + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + ) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + self.set_block_state(state, block_state) + + return components, state + + +# composing the denoising loops +class StableDiffusionXLDenoiseStep(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [ + StableDiffusionXLLoopBeforeDenoiser, + StableDiffusionXLLoopDenoiser, + StableDiffusionXLLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `StableDiffusionXLLoopBeforeDenoiser`\n" + " - `StableDiffusionXLLoopDenoiser`\n" + " - `StableDiffusionXLLoopAfterDenoiser`\n" + "This block supports both text2img and img2img tasks." + ) + + +# control_cond +class StableDiffusionXLControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [ + StableDiffusionXLLoopBeforeDenoiser, + StableDiffusionXLControlNetLoopDenoiser, + StableDiffusionXLLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents with controlnet. \n" + "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `StableDiffusionXLLoopBeforeDenoiser`\n" + " - `StableDiffusionXLControlNetLoopDenoiser`\n" + " - `StableDiffusionXLLoopAfterDenoiser`\n" + "This block supports using controlnet for both text2img and img2img tasks." + ) + + +# mask +class StableDiffusionXLInpaintDenoiseStep(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [ + StableDiffusionXLInpaintLoopBeforeDenoiser, + StableDiffusionXLLoopDenoiser, + StableDiffusionXLInpaintLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents(for inpainting task only). \n" + "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `StableDiffusionXLInpaintLoopBeforeDenoiser`\n" + " - `StableDiffusionXLLoopDenoiser`\n" + " - `StableDiffusionXLInpaintLoopAfterDenoiser`\n" + "This block onlysupports inpainting tasks." + ) + + +# control_cond + mask +class StableDiffusionXLInpaintControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [ + StableDiffusionXLInpaintLoopBeforeDenoiser, + StableDiffusionXLControlNetLoopDenoiser, + StableDiffusionXLInpaintLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents(for inpainting task only) with controlnet. \n" + "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `StableDiffusionXLInpaintLoopBeforeDenoiser`\n" + " - `StableDiffusionXLControlNetLoopDenoiser`\n" + " - `StableDiffusionXLInpaintLoopAfterDenoiser`\n" + "This block only supports using controlnet for inpainting tasks." + ) diff --git a/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..8387ae7bd6b6d06bee9e1b9b0bf63a6f28191781 --- /dev/null +++ b/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py @@ -0,0 +1,885 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...utils import ( + USE_PEFT_BACKEND, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from .modular_pipeline import StableDiffusionXLModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class StableDiffusionXLIPAdapterStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "IP Adapter step that prepares ip adapter image embeddings.\n" + "Note that this step only prepares the embeddings - in order for it to work correctly, " + "you need to load ip adapter weights into unet via ModularPipeline.load_ip_adapter() and pipeline.set_ip_adapter_scale().\n" + "See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" + " for more details" + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("image_encoder", CLIPVisionModelWithProjection), + ComponentSpec( + "feature_extractor", + CLIPImageProcessor, + config=FrozenDict({"size": 224, "crop_size": 224}), + default_creation_method="from_config", + ), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "ip_adapter_image", + PipelineImageInput, + required=True, + description="The image(s) to be used as ip adapter", + ) + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), + OutputParam( + "negative_ip_adapter_embeds", + type_hint=torch.Tensor, + description="Negative IP adapter image embeddings", + ), + ] + + @staticmethod + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self->components + def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(components.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = components.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = components.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = components.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, + components, + ip_adapter_image, + ip_adapter_image_embeds, + device, + num_images_per_prompt, + prepare_unconditional_embeds, + ): + image_embeds = [] + if prepare_unconditional_embeds: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + components, single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if prepare_unconditional_embeds: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if prepare_unconditional_embeds: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if prepare_unconditional_embeds: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 + block_state.device = components._execution_device + + block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds( + components, + ip_adapter_image=block_state.ip_adapter_image, + ip_adapter_image_embeds=None, + device=block_state.device, + num_images_per_prompt=1, + prepare_unconditional_embeds=block_state.prepare_unconditional_embeds, + ) + if block_state.prepare_unconditional_embeds: + block_state.negative_ip_adapter_embeds = [] + for i, image_embeds in enumerate(block_state.ip_adapter_embeds): + negative_image_embeds, image_embeds = image_embeds.chunk(2) + block_state.negative_ip_adapter_embeds.append(negative_image_embeds) + block_state.ip_adapter_embeds[i] = image_embeds + + self.set_block_state(state, block_state) + return components, state + + +class StableDiffusionXLTextEncoderStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return "Text Encoder step that generate text_embeddings to guide the image generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", CLIPTextModel), + ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), + ComponentSpec("tokenizer", CLIPTokenizer), + ComponentSpec("tokenizer_2", CLIPTokenizer), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config", + ), + ] + + @property + def expected_configs(self) -> list[ConfigSpec]: + return [ConfigSpec("force_zeros_for_empty_prompt", True)] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("prompt"), + InputParam("prompt_2"), + InputParam("negative_prompt"), + InputParam("negative_prompt_2"), + InputParam("cross_attention_kwargs"), + InputParam("clip_skip"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="text embeddings used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="negative text embeddings used to guide the image generation", + ), + OutputParam( + "pooled_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="pooled text embeddings used to guide the image generation", + ), + OutputParam( + "negative_pooled_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="negative pooled text embeddings used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + if block_state.prompt is not None and ( + not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list) + ): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") + elif block_state.prompt_2 is not None and ( + not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list) + ): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}") + + @staticmethod + def encode_prompt( + components, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prepare_unconditional_embeds: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prepare_unconditional_embeds (`bool`): + whether to use prepare unconditional embeddings or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or components._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin): + components._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if components.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(components.text_encoder, lora_scale) + else: + scale_lora_layers(components.text_encoder, lora_scale) + + if components.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale) + else: + scale_lora_layers(components.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = ( + [components.tokenizer, components.tokenizer_2] + if components.tokenizer is not None + else [components.tokenizer_2] + ) + text_encoders = ( + [components.text_encoder, components.text_encoder_2] + if components.text_encoder is not None + else [components.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(components, TextualInversionLoaderMixin): + prompt = components.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt + if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif prepare_unconditional_embeds and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(components, TextualInversionLoaderMixin): + negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if components.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if prepare_unconditional_embeds: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if components.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=components.text_encoder_2.dtype, device=device + ) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if prepare_unconditional_embeds: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if components.text_encoder is not None: + if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(components.text_encoder, lora_scale) + + if components.text_encoder_2 is not None: + if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(components.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + # Get inputs and intermediates + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 + block_state.device = components._execution_device + + # Encode input prompt + block_state.text_encoder_lora_scale = ( + block_state.cross_attention_kwargs.get("scale", None) + if block_state.cross_attention_kwargs is not None + else None + ) + ( + block_state.prompt_embeds, + block_state.negative_prompt_embeds, + block_state.pooled_prompt_embeds, + block_state.negative_pooled_prompt_embeds, + ) = self.encode_prompt( + components, + block_state.prompt, + block_state.prompt_2, + block_state.device, + 1, + block_state.prepare_unconditional_embeds, + block_state.negative_prompt, + block_state.negative_prompt_2, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + lora_scale=block_state.text_encoder_lora_scale, + clip_skip=block_state.clip_skip, + ) + # Add outputs + self.set_block_state(state, block_state) + return components, state + + +class StableDiffusionXLVaeEncoderStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return "Vae Encoder step that encode the input image into a latent representation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("image", required=True), + InputParam("height"), + InputParam("width"), + InputParam("generator"), + InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + InputParam( + "preprocess_kwargs", + type_hint=dict | None, + description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "image_latents", + type_hint=torch.Tensor, + description="The latents representing the reference image for image-to-image/inpainting generation", + ) + ] + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std + else: + image_latents = components.vae.config.scaling_factor * image_latents + + return image_latents + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.preprocess_kwargs = block_state.preprocess_kwargs or {} + block_state.device = components._execution_device + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + + image = components.image_processor.preprocess( + block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs + ) + image = image.to(device=block_state.device, dtype=block_state.dtype) + block_state.batch_size = image.shape[0] + + # if generator is a list, make sure the length of it matches the length of images (both should be batch_size) + if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" + f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators." + ) + + block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator) + + self.set_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLInpaintVaeEncoderStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ComponentSpec( + "mask_processor", + VaeImageProcessor, + config=FrozenDict( + {"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True} + ), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Vae encoder step that prepares the image and mask for the inpainting process" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("height"), + InputParam("width"), + InputParam("image", required=True), + InputParam("mask_image", required=True), + InputParam("padding_mask_crop"), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + InputParam("generator"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "image_latents", type_hint=torch.Tensor, description="The latents representation of the input image" + ), + OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), + OutputParam( + "masked_image_latents", + type_hint=torch.Tensor, + description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)", + ), + OutputParam( + "crops_coords", + type_hint=tuple[int, int] | None, + description="The crop coordinates to use for the preprocess/postprocess of the image and mask", + ), + ] + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + image_latents = components.vae.config.scaling_factor * image_latents + + return image_latents + + # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents + # do not accept do_classifier_free_guidance + def prepare_mask_latents( + self, components, mask, masked_image, batch_size, height, width, dtype, device, generator + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + if masked_image is not None and masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if masked_image is not None: + if masked_image_latents is None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) + + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + + if block_state.height is None: + block_state.height = components.default_height + if block_state.width is None: + block_state.width = components.default_width + + if block_state.padding_mask_crop is not None: + block_state.crops_coords = components.mask_processor.get_crop_region( + block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop + ) + block_state.resize_mode = "fill" + else: + block_state.crops_coords = None + block_state.resize_mode = "default" + + image = components.image_processor.preprocess( + block_state.image, + height=block_state.height, + width=block_state.width, + crops_coords=block_state.crops_coords, + resize_mode=block_state.resize_mode, + ) + image = image.to(dtype=torch.float32) + + mask = components.mask_processor.preprocess( + block_state.mask_image, + height=block_state.height, + width=block_state.width, + resize_mode=block_state.resize_mode, + crops_coords=block_state.crops_coords, + ) + block_state.masked_image = image * (mask < 0.5) + + block_state.batch_size = image.shape[0] + image = image.to(device=block_state.device, dtype=block_state.dtype) + block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator) + + # 7. Prepare mask latent variables + block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( + components, + mask, + block_state.masked_image, + block_state.batch_size, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + ) + + self.set_block_state(state, block_state) + + return components, state diff --git a/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_stable_diffusion_xl.py b/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_stable_diffusion_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..a7a18e51477716619d43b1bd2933b86fed53bf88 --- /dev/null +++ b/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_stable_diffusion_xl.py @@ -0,0 +1,512 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import OutputParam +from .before_denoise import ( + StableDiffusionXLControlNetInputStep, + StableDiffusionXLControlNetUnionInputStep, + StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, + StableDiffusionXLImg2ImgPrepareLatentsStep, + StableDiffusionXLImg2ImgSetTimestepsStep, + StableDiffusionXLInpaintPrepareLatentsStep, + StableDiffusionXLInputStep, + StableDiffusionXLPrepareAdditionalConditioningStep, + StableDiffusionXLPrepareLatentsStep, + StableDiffusionXLSetTimestepsStep, +) +from .decoders import ( + StableDiffusionXLDecodeStep, + StableDiffusionXLInpaintOverlayMaskStep, +) +from .denoise import ( + StableDiffusionXLControlNetDenoiseStep, + StableDiffusionXLDenoiseStep, + StableDiffusionXLInpaintControlNetDenoiseStep, + StableDiffusionXLInpaintDenoiseStep, +) +from .encoders import ( + StableDiffusionXLInpaintVaeEncoderStep, + StableDiffusionXLIPAdapterStep, + StableDiffusionXLTextEncoderStep, + StableDiffusionXLVaeEncoderStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# auto blocks & sequential blocks & mappings + + +# vae encoder (run before before_denoise) +class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep] + block_names = ["inpaint", "img2img"] + block_trigger_inputs = ["mask_image", "image"] + + @property + def description(self): + return ( + "Vae encoder step that encode the image inputs into their latent representations.\n" + + "This is an auto pipeline block that works for both inpainting and img2img tasks.\n" + + " - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n" + + " - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided." + + " - if neither `mask_image` nor `image` is provided, step will be skipped." + ) + + +# optional ip-adapter (run before input step) +class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLIPAdapterStep] + block_names = ["ip_adapter"] + block_trigger_inputs = ["ip_adapter_image"] + + @property + def description(self): + return "Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n" + + +# before_denoise: text2img +class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + StableDiffusionXLSetTimestepsStep, + StableDiffusionXLPrepareLatentsStep, + StableDiffusionXLPrepareAdditionalConditioningStep, + ] + block_names = ["set_timesteps", "prepare_latents", "prepare_add_cond"] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs for the denoise step.\n" + + "This is a sequential pipeline blocks:\n" + + " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n" + + " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n" + + " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + ) + + +# before_denoise: img2img +class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + StableDiffusionXLImg2ImgSetTimestepsStep, + StableDiffusionXLImg2ImgPrepareLatentsStep, + StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, + ] + block_names = ["set_timesteps", "prepare_latents", "prepare_add_cond"] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs for the denoise step for img2img task.\n" + + "This is a sequential pipeline blocks:\n" + + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + + " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n" + + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + ) + + +# before_denoise: inpainting +class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + StableDiffusionXLImg2ImgSetTimestepsStep, + StableDiffusionXLInpaintPrepareLatentsStep, + StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, + ] + block_names = ["set_timesteps", "prepare_latents", "prepare_add_cond"] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs for the denoise step for inpainting task.\n" + + "This is a sequential pipeline blocks:\n" + + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + + " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" + + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + ) + + +# before_denoise: all task (text2img, img2img, inpainting) +class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks): + block_classes = [ + StableDiffusionXLInpaintBeforeDenoiseStep, + StableDiffusionXLImg2ImgBeforeDenoiseStep, + StableDiffusionXLBeforeDenoiseStep, + ] + block_names = ["inpaint", "img2img", "text2img"] + block_trigger_inputs = ["mask", "image_latents", None] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs for the denoise step.\n" + + "This is an auto pipeline block that works for text2img, img2img and inpainting tasks as well as controlnet, controlnet_union.\n" + + " - `StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n" + + " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + + " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided.\n" + ) + + +# optional controlnet input step (after before_denoise, before denoise) +# works for both controlnet and controlnet_union +class StableDiffusionXLAutoControlNetInputStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep] + block_names = ["controlnet_union", "controlnet"] + block_trigger_inputs = ["control_mode", "control_image"] + + @property + def description(self): + return ( + "Controlnet Input step that prepare the controlnet input.\n" + + "This is an auto pipeline block that works for both controlnet and controlnet_union.\n" + + " (it should be called right before the denoise step)" + + " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.\n" + + " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided." + + " - if neither `control_mode` nor `control_image` is provided, step will be skipped." + ) + + +# denoise: controlnet (text2img, img2img, inpainting) +class StableDiffusionXLAutoControlNetDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintControlNetDenoiseStep, StableDiffusionXLControlNetDenoiseStep] + block_names = ["inpaint_controlnet_denoise", "controlnet_denoise"] + block_trigger_inputs = ["mask", "controlnet_cond"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents with controlnet. " + "This is a auto pipeline block that using controlnet for text2img, img2img and inpainting tasks." + "This block should not be used without a controlnet_cond input" + " - `StableDiffusionXLInpaintControlNetDenoiseStep` (inpaint_controlnet_denoise) is used when mask is provided." + " - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when mask is not provided but controlnet_cond is provided." + " - If neither mask nor controlnet_cond are provided, step will be skipped." + ) + + +# denoise: all task with or without controlnet (text2img, img2img, inpainting) +class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): + block_classes = [ + StableDiffusionXLAutoControlNetDenoiseStep, + StableDiffusionXLInpaintDenoiseStep, + StableDiffusionXLDenoiseStep, + ] + block_names = ["controlnet_denoise", "inpaint_denoise", "denoise"] + block_trigger_inputs = ["controlnet_cond", "mask", None] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. " + "This is a auto pipeline block that works for text2img, img2img and inpainting tasks. And can be used with or without controlnet." + " - `StableDiffusionXLAutoControlNetDenoiseStep` (controlnet_denoise) is used when controlnet_cond is provided (support controlnet withtext2img, img2img and inpainting tasks)." + " - `StableDiffusionXLInpaintDenoiseStep` (inpaint_denoise) is used when mask is provided (support inpainting tasks)." + " - `StableDiffusionXLDenoiseStep` (denoise) is used when neither mask nor controlnet_cond are provided (support text2img and img2img tasks)." + ) + + +# decode: inpaint +class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLDecodeStep, StableDiffusionXLInpaintOverlayMaskStep] + block_names = ["decode", "mask_overlay"] + + @property + def description(self): + return ( + "Inpaint decode step that decode the denoised latents into images outputs.\n" + + "This is a sequential pipeline blocks:\n" + + " - `StableDiffusionXLDecodeStep` is used to decode the denoised latents into images\n" + + " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image" + ) + + +# decode: all task (text2img, img2img, inpainting) +class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep] + block_names = ["inpaint", "non-inpaint"] + block_trigger_inputs = ["padding_mask_crop", None] + + @property + def description(self): + return ( + "Decode step that decode the denoised latents into images outputs.\n" + + "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n" + + " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + + " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided." + ) + + +class StableDiffusionXLCoreDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + StableDiffusionXLInputStep, + StableDiffusionXLAutoBeforeDenoiseStep, + StableDiffusionXLAutoControlNetInputStep, + StableDiffusionXLAutoDenoiseStep, + ] + block_names = ["input", "before_denoise", "controlnet_input", "denoise"] + + @property + def description(self): + return ( + "Core step that performs the denoising process. \n" + + " - `StableDiffusionXLInputStep` (input) standardizes the inputs for the denoising step.\n" + + " - `StableDiffusionXLAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n" + + " - `StableDiffusionXLAutoControlNetInputStep` (controlnet_input) prepares the controlnet input.\n" + + " - `StableDiffusionXLAutoDenoiseStep` (denoise) iteratively denoises the latents.\n\n" + + "This step support text-to-image, image-to-image, inpainting, with or without controlnet/controlnet_union/ip_adapter for Stable Diffusion XL:\n" + + "- for image-to-image generation, you need to provide `image_latents`\n" + + "- for inpainting, you need to provide `mask_image` and `image_latents`\n" + + "- to run the controlnet workflow, you need to provide `control_image`\n" + + "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + + "- to run the ip_adapter workflow, you need to load ip_adapter into your unet and provide `ip_adapter_embeds`\n" + + "- for text-to-image generation, all you need to provide is prompt embeddings\n" + ) + + +# ip-adapter, controlnet, text2img, img2img, inpainting +# auto_docstring +class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion + XL. + + Supported workflows: + - `text2image`: requires `prompt` + - `image2image`: requires `image`, `prompt` + - `inpainting`: requires `mask_image`, `image`, `prompt` + - `controlnet_text2image`: requires `control_image`, `prompt` + - `controlnet_image2image`: requires `control_image`, `image`, `prompt` + - `controlnet_inpainting`: requires `control_image`, `mask_image`, `image`, `prompt` + - `controlnet_union_text2image`: requires `control_image`, `control_mode`, `prompt` + - `controlnet_union_image2image`: requires `control_image`, `control_mode`, `image`, `prompt` + - `controlnet_union_inpainting`: requires `control_image`, `control_mode`, `mask_image`, `image`, `prompt` + - `ip_adapter_text2image`: requires `ip_adapter_image`, `prompt` + - `ip_adapter_image2image`: requires `ip_adapter_image`, `image`, `prompt` + - `ip_adapter_inpainting`: requires `ip_adapter_image`, `mask_image`, `image`, `prompt` + - `ip_adapter_controlnet_text2image`: requires `ip_adapter_image`, `control_image`, `prompt` + - `ip_adapter_controlnet_image2image`: requires `ip_adapter_image`, `control_image`, `image`, `prompt` + - `ip_adapter_controlnet_inpainting`: requires `ip_adapter_image`, `control_image`, `mask_image`, `image`, + `prompt` + - `ip_adapter_controlnet_union_text2image`: requires `ip_adapter_image`, `control_image`, `control_mode`, + `prompt` + - `ip_adapter_controlnet_union_image2image`: requires `ip_adapter_image`, `control_image`, `control_mode`, + `image`, `prompt` + - `ip_adapter_controlnet_union_inpainting`: requires `ip_adapter_image`, `control_image`, `control_mode`, + `mask_image`, `image`, `prompt` + + Components: + text_encoder (`CLIPTextModel`) text_encoder_2 (`CLIPTextModelWithProjection`) tokenizer (`CLIPTokenizer`) + tokenizer_2 (`CLIPTokenizer`) guider (`ClassifierFreeGuidance`) image_encoder + (`CLIPVisionModelWithProjection`) feature_extractor (`CLIPImageProcessor`) unet (`UNet2DConditionModel`) vae + (`AutoencoderKL`) image_processor (`VaeImageProcessor`) mask_processor (`VaeImageProcessor`) scheduler + (`EulerDiscreteScheduler`) controlnet (`ControlNetUnionModel`) control_image_processor (`VaeImageProcessor`) + + Configs: + force_zeros_for_empty_prompt (default: True) requires_aesthetics_score (default: False) + + Inputs: + prompt (`None`, *optional*): + TODO: Add description. + prompt_2 (`None`, *optional*): + TODO: Add description. + negative_prompt (`None`, *optional*): + TODO: Add description. + negative_prompt_2 (`None`, *optional*): + TODO: Add description. + cross_attention_kwargs (`None`, *optional*): + TODO: Add description. + clip_skip (`None`, *optional*): + TODO: Add description. + ip_adapter_image (`Image | ndarray | Tensor | list | list | list`, *optional*): + The image(s) to be used as ip adapter + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + image (`None`, *optional*): + TODO: Add description. + mask_image (`None`, *optional*): + TODO: Add description. + padding_mask_crop (`None`, *optional*): + TODO: Add description. + dtype (`dtype`, *optional*): + The dtype of the model inputs + generator (`None`, *optional*): + TODO: Add description. + preprocess_kwargs (`dict | NoneType`, *optional*): + A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under + `self.image_processor` in [diffusers.image_processor.VaeImageProcessor] + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + ip_adapter_embeds (`list`, *optional*): + Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step. + negative_ip_adapter_embeds (`list`, *optional*): + Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + denoising_end (`None`, *optional*): + TODO: Add description. + strength (`None`, *optional*, defaults to 0.3): + TODO: Add description. + denoising_start (`None`, *optional*): + TODO: Add description. + latents (`None`): + TODO: Add description. + image_latents (`Tensor`, *optional*): + The latents representing the reference image for image-to-image/inpainting generation. Can be generated + in vae_encode step. + mask (`Tensor`, *optional*): + The mask for the inpainting generation. Can be generated in vae_encode step. + masked_image_latents (`Tensor`, *optional*): + The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be + generated in vae_encode step. + original_size (`None`, *optional*): + TODO: Add description. + target_size (`None`, *optional*): + TODO: Add description. + negative_original_size (`None`, *optional*): + TODO: Add description. + negative_target_size (`None`, *optional*): + TODO: Add description. + crops_coords_top_left (`None`, *optional*, defaults to (0, 0)): + TODO: Add description. + negative_crops_coords_top_left (`None`, *optional*, defaults to (0, 0)): + TODO: Add description. + aesthetic_score (`None`, *optional*, defaults to 6.0): + TODO: Add description. + negative_aesthetic_score (`None`, *optional*, defaults to 2.0): + TODO: Add description. + control_image (`None`, *optional*): + TODO: Add description. + control_mode (`None`, *optional*): + TODO: Add description. + control_guidance_start (`None`, *optional*, defaults to 0.0): + TODO: Add description. + control_guidance_end (`None`, *optional*, defaults to 1.0): + TODO: Add description. + controlnet_conditioning_scale (`None`, *optional*, defaults to 1.0): + TODO: Add description. + guess_mode (`None`, *optional*, defaults to False): + TODO: Add description. + crops_coords (`tuple | NoneType`, *optional*): + The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can + be generated in vae_encode step. + controlnet_cond (`Tensor`, *optional*): + The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step. + conditioning_scale (`float`, *optional*): + The controlnet conditioning scale value to use for the denoising process. Can be generated in + prepare_controlnet_inputs step. + controlnet_keep (`list`, *optional*): + The controlnet keep values to use for the denoising process. Can be generated in + prepare_controlnet_inputs step. + **denoiser_input_fields (`None`, *optional*): + All conditional model inputs that need to be prepared with guider. It should contain + prompt_embeds/negative_prompt_embeds, add_time_ids/negative_add_time_ids, + pooled_prompt_embeds/negative_pooled_prompt_embeds, and ip_adapter_embeds/negative_ip_adapter_embeds + (optional).please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when + they are created and added to the pipeline state + eta (`None`, *optional*, defaults to 0.0): + TODO: Add description. + output_type (`None`, *optional*, defaults to pil): + TODO: Add description. + + Outputs: + images (`list`): + Generated images. + """ + + block_classes = [ + StableDiffusionXLTextEncoderStep, + StableDiffusionXLAutoIPAdapterStep, + StableDiffusionXLAutoVaeEncoderStep, + StableDiffusionXLCoreDenoiseStep, + StableDiffusionXLAutoDecodeStep, + ] + block_names = [ + "text_encoder", + "ip_adapter", + "vae_encoder", + "denoise", + "decode", + ] + + _workflow_map = { + "text2image": {"prompt": True}, + "image2image": {"image": True, "prompt": True}, + "inpainting": {"mask_image": True, "image": True, "prompt": True}, + "controlnet_text2image": {"control_image": True, "prompt": True}, + "controlnet_image2image": {"control_image": True, "image": True, "prompt": True}, + "controlnet_inpainting": {"control_image": True, "mask_image": True, "image": True, "prompt": True}, + "controlnet_union_text2image": {"control_image": True, "control_mode": True, "prompt": True}, + "controlnet_union_image2image": {"control_image": True, "control_mode": True, "image": True, "prompt": True}, + "controlnet_union_inpainting": { + "control_image": True, + "control_mode": True, + "mask_image": True, + "image": True, + "prompt": True, + }, + "ip_adapter_text2image": {"ip_adapter_image": True, "prompt": True}, + "ip_adapter_image2image": {"ip_adapter_image": True, "image": True, "prompt": True}, + "ip_adapter_inpainting": {"ip_adapter_image": True, "mask_image": True, "image": True, "prompt": True}, + "ip_adapter_controlnet_text2image": {"ip_adapter_image": True, "control_image": True, "prompt": True}, + "ip_adapter_controlnet_image2image": { + "ip_adapter_image": True, + "control_image": True, + "image": True, + "prompt": True, + }, + "ip_adapter_controlnet_inpainting": { + "ip_adapter_image": True, + "control_image": True, + "mask_image": True, + "image": True, + "prompt": True, + }, + "ip_adapter_controlnet_union_text2image": { + "ip_adapter_image": True, + "control_image": True, + "control_mode": True, + "prompt": True, + }, + "ip_adapter_controlnet_union_image2image": { + "ip_adapter_image": True, + "control_image": True, + "control_mode": True, + "image": True, + "prompt": True, + }, + "ip_adapter_controlnet_union_inpainting": { + "ip_adapter_image": True, + "control_image": True, + "control_mode": True, + "mask_image": True, + "image": True, + "prompt": True, + }, + } + + @property + def description(self): + return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL." + + @property + def outputs(self): + return [OutputParam.template("images")] diff --git a/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py b/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..209e2b11814f8e1242db1de79e79977add1ad7a5 --- /dev/null +++ b/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py @@ -0,0 +1,360 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import PIL +import torch + +from ...image_processor import PipelineImageInput +from ...loaders import ModularIPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...pipelines.pipeline_utils import StableDiffusionMixin +from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from ...utils import logging +from ..modular_pipeline import ModularPipeline +from ..modular_pipeline_utils import InputParam, OutputParam + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# YiYi TODO: move to a different file? stable_diffusion_xl_module should have its own folder? +# YiYi Notes: model specific components: +## (1) it should inherit from ModularPipeline +## (2) acts like a container that holds components and configs +## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents +## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin) +## (5) how to use together with Components_manager? +class StableDiffusionXLModularPipeline( + ModularPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + ModularIPAdapterMixin, +): + """ + A ModularPipeline for Stable Diffusion XL. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "StableDiffusionXLAutoBlocks" + + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + default_sample_size = 128 + if hasattr(self, "unet") and self.unet is not None: + default_sample_size = self.unet.config.sample_size + return default_sample_size + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + # YiYi TODO: change to num_channels_latents + @property + def num_channels_unet(self): + num_channels_unet = 4 + if hasattr(self, "unet") and self.unet is not None: + num_channels_unet = self.unet.config.in_channels + return num_channels_unet + + @property + def num_channels_latents(self): + num_channels_latents = 4 + if hasattr(self, "vae") and self.vae is not None: + num_channels_latents = self.vae.config.latent_channels + return num_channels_latents + + +# YiYi/Sayak TODO: not used yet, maintain a list of schema that can be used across all pipeline blocks +# auto_docstring +SDXL_INPUTS_SCHEMA = { + "prompt": InputParam( + "prompt", type_hint=str | list[str], description="The prompt or prompts to guide the image generation" + ), + "prompt_2": InputParam( + "prompt_2", + type_hint=str | list[str], + description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2", + ), + "negative_prompt": InputParam( + "negative_prompt", + type_hint=str | list[str], + description="The prompt or prompts not to guide the image generation", + ), + "negative_prompt_2": InputParam( + "negative_prompt_2", + type_hint=str | list[str], + description="The negative prompt or prompts for text_encoder_2", + ), + "cross_attention_kwargs": InputParam( + "cross_attention_kwargs", + type_hint=dict | None, + description="Kwargs dictionary passed to the AttentionProcessor", + ), + "clip_skip": InputParam( + "clip_skip", type_hint=int | None, description="Number of layers to skip in CLIP text encoder" + ), + "image": InputParam( + "image", + type_hint=PipelineImageInput, + required=True, + description="The image(s) to modify for img2img or inpainting", + ), + "mask_image": InputParam( + "mask_image", + type_hint=PipelineImageInput, + required=True, + description="Mask image for inpainting, white pixels will be repainted", + ), + "generator": InputParam( + "generator", + type_hint=torch.Generator | list[torch.Generator] | None, + description="Generator(s) for deterministic generation", + ), + "height": InputParam("height", type_hint=int | None, description="Height in pixels of the generated image"), + "width": InputParam("width", type_hint=int | None, description="Width in pixels of the generated image"), + "num_images_per_prompt": InputParam( + "num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt" + ), + "num_inference_steps": InputParam( + "num_inference_steps", type_hint=int, default=50, description="Number of denoising steps" + ), + "timesteps": InputParam( + "timesteps", type_hint=torch.Tensor | None, description="Custom timesteps for the denoising process" + ), + "sigmas": InputParam( + "sigmas", type_hint=torch.Tensor | None, description="Custom sigmas for the denoising process" + ), + "denoising_end": InputParam( + "denoising_end", + type_hint=float | None, + description="Fraction of denoising process to complete before termination", + ), + # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999 + "strength": InputParam( + "strength", type_hint=float, default=0.3, description="How much to transform the reference image" + ), + "denoising_start": InputParam( + "denoising_start", type_hint=float | None, description="Starting point of the denoising process" + ), + "latents": InputParam( + "latents", type_hint=torch.Tensor | None, description="Pre-generated noisy latents for image generation" + ), + "padding_mask_crop": InputParam( + "padding_mask_crop", + type_hint=tuple[int, int] | None, + description="Size of margin in crop for image and mask", + ), + "original_size": InputParam( + "original_size", + type_hint=tuple[int, int] | None, + description="Original size of the image for SDXL's micro-conditioning", + ), + "target_size": InputParam( + "target_size", type_hint=tuple[int, int] | None, description="Target size for SDXL's micro-conditioning" + ), + "negative_original_size": InputParam( + "negative_original_size", + type_hint=tuple[int, int] | None, + description="Negative conditioning based on image resolution", + ), + "negative_target_size": InputParam( + "negative_target_size", + type_hint=tuple[int, int] | None, + description="Negative conditioning based on target resolution", + ), + "crops_coords_top_left": InputParam( + "crops_coords_top_left", + type_hint=tuple[int, int], + default=(0, 0), + description="Top-left coordinates for SDXL's micro-conditioning", + ), + "negative_crops_coords_top_left": InputParam( + "negative_crops_coords_top_left", + type_hint=tuple[int, int], + default=(0, 0), + description="Negative conditioning crop coordinates", + ), + "aesthetic_score": InputParam( + "aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image" + ), + "negative_aesthetic_score": InputParam( + "negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score" + ), + "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"), + "output_type": InputParam( + "output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)" + ), + "ip_adapter_image": InputParam( + "ip_adapter_image", + type_hint=PipelineImageInput, + required=True, + description="Image(s) to be used as IP adapter", + ), + "control_image": InputParam( + "control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition" + ), + "control_guidance_start": InputParam( + "control_guidance_start", + type_hint=float | list[float], + default=0.0, + description="When ControlNet starts applying", + ), + "control_guidance_end": InputParam( + "control_guidance_end", + type_hint=float | list[float], + default=1.0, + description="When ControlNet stops applying", + ), + "controlnet_conditioning_scale": InputParam( + "controlnet_conditioning_scale", + type_hint=float | list[float], + default=1.0, + description="Scale factor for ControlNet outputs", + ), + "guess_mode": InputParam( + "guess_mode", + type_hint=bool, + default=False, + description="Enables ControlNet encoder to recognize input without prompts", + ), + "control_mode": InputParam( + "control_mode", type_hint=list[int], required=True, description="Control mode for union controlnet" + ), + "prompt_embeds": InputParam( + "prompt_embeds", + type_hint=torch.Tensor, + required=True, + description="Text embeddings used to guide image generation", + ), + "negative_prompt_embeds": InputParam( + "negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings" + ), + "pooled_prompt_embeds": InputParam( + "pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings" + ), + "negative_pooled_prompt_embeds": InputParam( + "negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings" + ), + "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"), + "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + "preprocess_kwargs": InputParam( + "preprocess_kwargs", type_hint=dict | None, description="Kwargs for ImageProcessor" + ), + "latent_timestep": InputParam( + "latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep" + ), + "image_latents": InputParam( + "image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image" + ), + "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"), + "masked_image_latents": InputParam( + "masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting" + ), + "add_time_ids": InputParam( + "add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning" + ), + "negative_add_time_ids": InputParam( + "negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids" + ), + "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), + "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), + "crops_coords": InputParam("crops_coords", type_hint=tuple[int] | None, description="Crop coordinates"), + "ip_adapter_embeds": InputParam( + "ip_adapter_embeds", type_hint=list[torch.Tensor], description="Image embeddings for IP-Adapter" + ), + "negative_ip_adapter_embeds": InputParam( + "negative_ip_adapter_embeds", + type_hint=list[torch.Tensor], + description="Negative image embeddings for IP-Adapter", + ), + "images": InputParam( + "images", + type_hint=list[PIL.Image.Image] | list[torch.Tensor] | list[np.array], + required=True, + description="Generated images", + ), +} + + +SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = { + "prompt_embeds": OutputParam( + "prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation" + ), + "negative_prompt_embeds": OutputParam( + "negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings" + ), + "pooled_prompt_embeds": OutputParam( + "pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings" + ), + "negative_pooled_prompt_embeds": OutputParam( + "negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings" + ), + "batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"), + "dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + "image_latents": OutputParam( + "image_latents", type_hint=torch.Tensor, description="Latents representing reference image" + ), + "mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"), + "masked_image_latents": OutputParam( + "masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting" + ), + "crops_coords": OutputParam("crops_coords", type_hint=tuple[int] | None, description="Crop coordinates"), + "timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"), + "num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"), + "latent_timestep": OutputParam( + "latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep" + ), + "add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"), + "negative_add_time_ids": OutputParam( + "negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids" + ), + "timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), + "latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"), + "noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), + "ip_adapter_embeds": OutputParam( + "ip_adapter_embeds", type_hint=list[torch.Tensor], description="Image embeddings for IP-Adapter" + ), + "negative_ip_adapter_embeds": OutputParam( + "negative_ip_adapter_embeds", + type_hint=list[torch.Tensor], + description="Negative image embeddings for IP-Adapter", + ), + "images": OutputParam( + "images", + type_hint=list[PIL.Image.Image] | list[torch.Tensor] | list[np.array], + description="Generated images", + ), +} + + +SDXL_OUTPUTS_SCHEMA = { + "images": OutputParam( + "images", + type_hint=tuple[list[PIL.Image.Image] | list[torch.Tensor] | list[np.array]] | StableDiffusionXLPipelineOutput, + description="The final generated images", + ) +} diff --git a/diffusers/modular_pipelines/wan/__init__.py b/diffusers/modular_pipelines/wan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..284b6c9fa436619460ce456dab54e6fd63f7906f --- /dev/null +++ b/diffusers/modular_pipelines/wan/__init__.py @@ -0,0 +1,63 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modular_blocks_wan"] = ["WanBlocks"] + _import_structure["modular_blocks_wan22"] = ["Wan22Blocks"] + _import_structure["modular_blocks_wan22_i2v"] = ["Wan22Image2VideoBlocks"] + _import_structure["modular_blocks_wan_i2v"] = ["WanImage2VideoAutoBlocks"] + _import_structure["modular_pipeline"] = [ + "Wan22Image2VideoModularPipeline", + "Wan22ModularPipeline", + "WanImage2VideoModularPipeline", + "WanModularPipeline", + ] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_blocks_wan import WanBlocks + from .modular_blocks_wan22 import Wan22Blocks + from .modular_blocks_wan22_i2v import Wan22Image2VideoBlocks + from .modular_blocks_wan_i2v import WanImage2VideoAutoBlocks + from .modular_pipeline import ( + Wan22Image2VideoModularPipeline, + Wan22ModularPipeline, + WanImage2VideoModularPipeline, + WanModularPipeline, + ) +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/modular_pipelines/wan/before_denoise.py b/diffusers/modular_pipelines/wan/before_denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..398b9665522c8877817c88ab4aeb034d5c4f3464 --- /dev/null +++ b/diffusers/modular_pipelines/wan/before_denoise.py @@ -0,0 +1,554 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import torch + +from ...models import WanTransformer3DModel +from ...schedulers import UniPCMultistepScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import WanModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that +# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by +# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the +# configuration of guider is. + + +def repeat_tensor_to_batch_size( + input_name: str, + input_tensor: torch.Tensor, + batch_size: int, + num_videos_per_prompt: int = 1, +) -> torch.Tensor: + """Repeat tensor elements to match the final batch size. + + This function expands a tensor's batch dimension to match the final batch size (batch_size * num_videos_per_prompt) + by repeating each element along dimension 0. + + The input tensor must have batch size 1 or batch_size. The function will: + - If batch size is 1: repeat each element (batch_size * num_videos_per_prompt) times + - If batch size equals batch_size: repeat each element num_videos_per_prompt times + + Args: + input_name (str): Name of the input tensor (used for error messages) + input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size. + batch_size (int): The base batch size (number of prompts) + num_videos_per_prompt (int, optional): Number of videos to generate per prompt. Defaults to 1. + + Returns: + torch.Tensor: The repeated tensor with final batch size (batch_size * num_videos_per_prompt) + + Raises: + ValueError: If input_tensor is not a torch.Tensor or has invalid batch size + + Examples: + tensor = torch.tensor([[1, 2, 3]]) # shape: [1, 3] repeated = repeat_tensor_to_batch_size("image", tensor, + batch_size=2, num_videos_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape: + [4, 3] + + tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: [2, 3] repeated = repeat_tensor_to_batch_size("image", + tensor, batch_size=2, num_videos_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]) + - shape: [4, 3] + """ + # make sure input is a tensor + if not isinstance(input_tensor, torch.Tensor): + raise ValueError(f"`{input_name}` must be a tensor") + + # make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts + if input_tensor.shape[0] == 1: + repeat_by = batch_size * num_videos_per_prompt + elif input_tensor.shape[0] == batch_size: + repeat_by = num_videos_per_prompt + else: + raise ValueError( + f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}" + ) + + # expand the tensor to match the batch_size * num_videos_per_prompt + input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0) + + return input_tensor + + +def calculate_dimension_from_latents( + latents: torch.Tensor, vae_scale_factor_temporal: int, vae_scale_factor_spatial: int +) -> tuple[int, int]: + """Calculate image dimensions from latent tensor dimensions. + + This function converts latent temporal and spatial dimensions to image temporal and spatial dimensions by + multiplying the latent num_frames/height/width by the VAE scale factor. + + Args: + latents (torch.Tensor): The latent tensor. Must have 4 or 5 dimensions. + Expected shapes: [batch, channels, height, width] or [batch, channels, frames, height, width] + vae_scale_factor_temporal (int): The scale factor used by the VAE to compress temporal dimension. + Typically 4 for most VAEs (video is 4x larger than latents in temporal dimension) + vae_scale_factor_spatial (int): The scale factor used by the VAE to compress spatial dimension. + Typically 8 for most VAEs (image is 8x larger than latents in each dimension) + + Returns: + tuple[int, int]: The calculated image dimensions as (height, width) + + Raises: + ValueError: If latents tensor doesn't have 4 or 5 dimensions + + """ + if latents.ndim != 5: + raise ValueError(f"latents must have 5 dimensions, but got {latents.ndim}") + + _, _, num_latent_frames, latent_height, latent_width = latents.shape + + num_frames = (num_latent_frames - 1) * vae_scale_factor_temporal + 1 + height = latent_height * vae_scale_factor_spatial + width = latent_width * vae_scale_factor_spatial + + return num_frames, height, width + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class WanTextInputStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return ( + "Input processing step that:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_videos_per_prompt`\n\n" + "All input tensors are expected to have either batch_size=1 or match the batch_size\n" + "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" + "have a final batch_size of batch_size * num_videos_per_prompt." + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("transformer", WanTransformer3DModel), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_videos_per_prompt", default=1), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + description="Pre-generated negative text embeddings. Can be generated from text_encoder step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_videos_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `transformer.dtype`)", + ), + ] + + def check_inputs(self, components, block_state): + if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None: + if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`" + f" {block_state.negative_prompt_embeds.shape}." + ) + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_videos_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1 + ) + + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat( + 1, block_state.num_videos_per_prompt, 1 + ) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view( + block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1 + ) + + self.set_block_state(state, block_state) + + return components, state + + +class WanAdditionalInputsStep(ModularPipelineBlocks): + model_name = "wan" + + def __init__( + self, + image_latent_inputs: list[str] = ["image_condition_latents"], + additional_batch_inputs: list[str] = [], + ): + """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n" + + This step handles multiple common tasks to prepare inputs for the denoising step: + 1. For encoded image latents, use it update height/width if None, and expands batch size + 2. For additional_batch_inputs: Only expands batch dimensions to match final batch size + + This is a dynamic block that allows you to configure which inputs to process. + + Args: + image_latent_inputs (list[str], optional): Names of image latent tensors to process. + In additional to adjust batch size of these inputs, they will be used to determine height/width. Can be + a single string or list of strings. Defaults to ["image_condition_latents"]. + additional_batch_inputs (List[str], optional): + Names of additional conditional input tensors to expand batch size. These tensors will only have their + batch dimensions adjusted to match the final batch size. Can be a single string or list of strings. + Defaults to []. + + Examples: + # Configure to process image_condition_latents (default behavior) WanAdditionalInputsStep() # Configure to + process image latents and additional batch inputs WanAdditionalInputsStep( + image_latent_inputs=["image_condition_latents"], additional_batch_inputs=["image_embeds"] + ) + """ + if not isinstance(image_latent_inputs, list): + image_latent_inputs = [image_latent_inputs] + if not isinstance(additional_batch_inputs, list): + additional_batch_inputs = [additional_batch_inputs] + + self._image_latent_inputs = image_latent_inputs + self._additional_batch_inputs = additional_batch_inputs + super().__init__() + + @property + def description(self) -> str: + # Functionality section + summary_section = ( + "Input processing step that:\n" + " 1. For image latent inputs: Updates height/width if None, and expands batch size\n" + " 2. For additional batch inputs: Expands batch dimensions to match final batch size" + ) + + # Inputs info + inputs_info = "" + if self._image_latent_inputs or self._additional_batch_inputs: + inputs_info = "\n\nConfigured inputs:" + if self._image_latent_inputs: + inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}" + if self._additional_batch_inputs: + inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}" + + # Placement guidance + placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." + + return summary_section + inputs_info + placement_section + + @property + def inputs(self) -> list[InputParam]: + inputs = [ + InputParam(name="num_videos_per_prompt", default=1), + InputParam(name="batch_size", required=True), + InputParam(name="height"), + InputParam(name="width"), + InputParam(name="num_frames"), + ] + + # Add image latent inputs + for image_latent_input_name in self._image_latent_inputs: + inputs.append(InputParam(name=image_latent_input_name)) + + # Add additional batch inputs + for input_name in self._additional_batch_inputs: + inputs.append(InputParam(name=input_name)) + + return inputs + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Process image latent inputs (height/width calculation, patchify, and batch expansion) + for image_latent_input_name in self._image_latent_inputs: + image_latent_tensor = getattr(block_state, image_latent_input_name) + if image_latent_tensor is None: + continue + + # 1. Calculate num_frames, height/width from latents + num_frames, height, width = calculate_dimension_from_latents( + image_latent_tensor, components.vae_scale_factor_temporal, components.vae_scale_factor_spatial + ) + block_state.num_frames = block_state.num_frames or num_frames + block_state.height = block_state.height or height + block_state.width = block_state.width or width + + # 3. Expand batch size + image_latent_tensor = repeat_tensor_to_batch_size( + input_name=image_latent_input_name, + input_tensor=image_latent_tensor, + num_videos_per_prompt=block_state.num_videos_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, image_latent_input_name, image_latent_tensor) + + # Process additional batch inputs (only batch expansion) + for input_name in self._additional_batch_inputs: + input_tensor = getattr(block_state, input_name) + if input_tensor is None: + continue + + # Only expand batch size + input_tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=input_tensor, + num_videos_per_prompt=block_state.num_videos_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, input_name, input_tensor) + + self.set_block_state(state, block_state) + return components, state + + +class WanSetTimestepsStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", UniPCMultistepScheduler), + ] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + ] + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, + block_state.num_inference_steps, + device, + block_state.timesteps, + block_state.sigmas, + ) + + self.set_block_state(state, block_state) + return components, state + + +class WanPrepareLatentsStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "Prepare latents step that prepares the latents for the text-to-video generation process" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("num_frames", type_hint=int), + InputParam("latents", type_hint=torch.Tensor | None), + InputParam("num_videos_per_prompt", type_hint=int, default=1), + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_videos_per_prompt`. Can be generated in input step.", + ), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ) + ] + + @staticmethod + def check_inputs(components, block_state): + if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or ( + block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}." + ) + if block_state.num_frames is not None and ( + block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0 + ): + raise ValueError( + f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}." + ) + + @staticmethod + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents with self->comp + def prepare_latents( + comp, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // comp.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // comp.vae_scale_factor_spatial, + int(width) // comp.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + device = components._execution_device + dtype = torch.float32 # Wan latents should be torch.float32 for best quality + + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + block_state.num_frames = block_state.num_frames or components.default_num_frames + + block_state.latents = self.prepare_latents( + components, + batch_size=block_state.batch_size * block_state.num_videos_per_prompt, + num_channels_latents=components.num_channels_latents, + height=block_state.height, + width=block_state.width, + num_frames=block_state.num_frames, + dtype=dtype, + device=device, + generator=block_state.generator, + latents=block_state.latents, + ) + + self.set_block_state(state, block_state) + + return components, state diff --git a/diffusers/modular_pipelines/wan/decoders.py b/diffusers/modular_pipelines/wan/decoders.py new file mode 100644 index 0000000000000000000000000000000000000000..9d1a4cf4f3486245ed15c3a753738ae7d0608e4d --- /dev/null +++ b/diffusers/modular_pipelines/wan/decoders.py @@ -0,0 +1,98 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import numpy as np +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKLWan +from ...utils import logging +from ...video_processor import VideoProcessor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class WanVaeDecoderStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLWan), + ComponentSpec( + "video_processor", + VideoProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images" + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents from the denoising step", + ), + InputParam( + "output_type", default="np", type_hint=str, description="The output type of the decoded videos" + ), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "videos", + type_hint=list[list[PIL.Image.Image]] | list[torch.Tensor] | list[np.ndarray], + description="The generated videos, can be a PIL.Image.Image, torch.Tensor or a numpy array", + ) + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + vae_dtype = components.vae.dtype + + latents = block_state.latents + latents_mean = ( + torch.tensor(components.vae.config.latents_mean) + .view(1, components.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view( + 1, components.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + latents = latents.to(vae_dtype) + block_state.videos = components.vae.decode(latents, return_dict=False)[0] + + output_type = getattr(block_state, "output_type", "np") + block_state.videos = components.video_processor.postprocess_video(block_state.videos, output_type=output_type) + + self.set_block_state(state, block_state) + + return components, state diff --git a/diffusers/modular_pipelines/wan/denoise.py b/diffusers/modular_pipelines/wan/denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..2f51f353012ea6452610d327087bff3db867a99c --- /dev/null +++ b/diffusers/modular_pipelines/wan/denoise.py @@ -0,0 +1,544 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import WanTransformer3DModel +from ...schedulers import UniPCMultistepScheduler +from ...utils import logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam +from .modular_pipeline import WanModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class WanLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepares the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `WanDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of the model inputs. Can be generated in input step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + block_state.latent_model_input = block_state.latents.to(block_state.dtype) + return components, block_state + + +class WanImage2VideoLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepares the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `WanDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "image_condition_latents", + required=True, + type_hint=torch.Tensor, + description="The image condition latents to use for the denoising process. Can be generated in prepare_first_frame_latents/prepare_first_last_frame_latents step.", + ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of the model inputs. Can be generated in input step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + block_state.latent_model_input = torch.cat( + [block_state.latents, block_state.image_condition_latents], dim=1 + ).to(block_state.dtype) + return components, block_state + + +class WanLoopDenoiser(ModularPipelineBlocks): + model_name = "wan" + + def __init__( + self, + guider_input_fields: dict[str, Any] = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")}, + ): + """Initialize a denoiser block that calls the denoiser model. This block is used in Wan2.1. + + Args: + guider_input_fields: A dictionary that maps each argument expected by the denoiser model + (for example, "encoder_hidden_states") to data stored on 'block_state'. The value can be either: + + - A tuple of strings. For instance, {"encoder_hidden_states": ("prompt_embeds", + "negative_prompt_embeds")} tells the guider to read `block_state.prompt_embeds` and + `block_state.negative_prompt_embeds` and pass them as the conditional and unconditional batches of + 'encoder_hidden_states'. + - A string. For example, {"encoder_hidden_image": "image_embeds"} makes the guider forward + `block_state.image_embeds` for both conditional and unconditional batches. + """ + if not isinstance(guider_input_fields, dict): + raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}") + self._guider_input_fields = guider_input_fields + super().__init__() + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 5.0}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", WanTransformer3DModel), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents with guidance. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `WanDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + inputs = [ + InputParam("attention_kwargs"), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + guider_input_names = [] + for value in self._guider_input_fields.values(): + if isinstance(value, tuple): + guider_input_names.extend(value) + else: + guider_input_names.append(value) + + for name in guider_input_names: + inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor)) + return inputs + + @torch.no_grad() + def __call__( + self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). + guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = guider_state_batch.as_dict() + cond_kwargs = { + k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v + for k, v in cond_kwargs.items() + if k in self._guider_input_fields.keys() + } + + # Predict the noise residual + # store the noise_pred in guider_state_batch so that we can apply guidance across all batches + guider_state_batch.noise_pred = components.transformer( + hidden_states=block_state.latent_model_input.to(block_state.dtype), + timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype), + attention_kwargs=block_state.attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + components.guider.cleanup_models(components.transformer) + + # Perform guidance + block_state.noise_pred = components.guider(guider_state)[0] + + return components, block_state + + +class Wan22LoopDenoiser(ModularPipelineBlocks): + model_name = "wan" + + def __init__( + self, + guider_input_fields: dict[str, Any] = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")}, + ): + """Initialize a denoiser block that calls the denoiser model. This block is used in Wan2.2. + + Args: + guider_input_fields: A dictionary that maps each argument expected by the denoiser model + (for example, "encoder_hidden_states") to data stored on `block_state`. The value can be either: + + - A tuple of strings. For instance, `{"encoder_hidden_states": ("prompt_embeds", + "negative_prompt_embeds")}` tells the guider to read `block_state.prompt_embeds` and + `block_state.negative_prompt_embeds` and pass them as the conditional and unconditional batches of + `encoder_hidden_states`. + - A string. For example, `{"encoder_hidden_image": "image_embeds"}` makes the guider forward + `block_state.image_embeds` for both conditional and unconditional batches. + """ + if not isinstance(guider_input_fields, dict): + raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}") + self._guider_input_fields = guider_input_fields + super().__init__() + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ComponentSpec( + "guider_2", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 3.0}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", WanTransformer3DModel), + ComponentSpec("transformer_2", WanTransformer3DModel), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents with guidance. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `WanDenoiseLoopWrapper`)" + ) + + @property + def expected_configs(self) -> list[ConfigSpec]: + return [ + ConfigSpec( + name="boundary_ratio", + default=0.875, + description="The boundary ratio to divide the denoising loop into high noise and low noise stages.", + ), + ] + + @property + def inputs(self) -> list[tuple[str, Any]]: + inputs = [ + InputParam("attention_kwargs"), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + guider_input_names = [] + for value in self._guider_input_fields.values(): + if isinstance(value, tuple): + guider_input_names.extend(value) + else: + guider_input_names.append(value) + + for name in guider_input_names: + inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor)) + return inputs + + @torch.no_grad() + def __call__( + self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + boundary_timestep = components.config.boundary_ratio * components.num_train_timesteps + if t >= boundary_timestep: + block_state.current_model = components.transformer + block_state.guider = components.guider + else: + block_state.current_model = components.transformer_2 + block_state.guider = components.guider_2 + + block_state.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). + guider_state = block_state.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + block_state.guider.prepare_models(block_state.current_model) + cond_kwargs = guider_state_batch.as_dict() + cond_kwargs = { + k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v + for k, v in cond_kwargs.items() + if k in self._guider_input_fields.keys() + } + + # Predict the noise residual + # store the noise_pred in guider_state_batch so that we can apply guidance across all batches + guider_state_batch.noise_pred = block_state.current_model( + hidden_states=block_state.latent_model_input.to(block_state.dtype), + timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype), + attention_kwargs=block_state.attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + block_state.guider.cleanup_models(block_state.current_model) + + # Perform guidance + block_state.noise_pred = block_state.guider(guider_state)[0] + + return components, block_state + + +class WanLoopAfterDenoiser(ModularPipelineBlocks): + model_name = "wan" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", UniPCMultistepScheduler), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that update the latents. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `WanDenoiseLoopWrapper`)" + ) + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + # Perform scheduler step using the predicted output + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred.float(), + t, + block_state.latents.float(), + return_dict=False, + )[0] + + if block_state.latents.dtype != latents_dtype: + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class WanDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoise the latents over `timesteps`. " + "The specific steps with each iteration can be customized with `sub_blocks` attributes" + ) + + @property + def loop_expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", UniPCMultistepScheduler), + ] + + @property + def loop_inputs(self) -> list[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + ) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + self.set_block_state(state, block_state) + + return components, state + + +class WanDenoiseStep(WanDenoiseLoopWrapper): + block_classes = [ + WanLoopBeforeDenoiser, + WanLoopDenoiser( + guider_input_fields={ + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + } + ), + WanLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `WanLoopBeforeDenoiser`\n" + " - `WanLoopDenoiser`\n" + " - `WanLoopAfterDenoiser`\n" + "This block supports text-to-video tasks for wan2.1." + ) + + +class Wan22DenoiseStep(WanDenoiseLoopWrapper): + block_classes = [ + WanLoopBeforeDenoiser, + Wan22LoopDenoiser( + guider_input_fields={ + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + } + ), + WanLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `WanLoopBeforeDenoiser`\n" + " - `Wan22LoopDenoiser`\n" + " - `WanLoopAfterDenoiser`\n" + "This block supports text-to-video tasks for Wan2.2." + ) + + +class WanImage2VideoDenoiseStep(WanDenoiseLoopWrapper): + block_classes = [ + WanImage2VideoLoopBeforeDenoiser, + WanLoopDenoiser( + guider_input_fields={ + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + "encoder_hidden_states_image": "image_embeds", + } + ), + WanLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `WanImage2VideoLoopBeforeDenoiser`\n" + " - `WanLoopDenoiser`\n" + " - `WanLoopAfterDenoiser`\n" + "This block supports image-to-video tasks for wan2.1." + ) + + +class Wan22Image2VideoDenoiseStep(WanDenoiseLoopWrapper): + block_classes = [ + WanImage2VideoLoopBeforeDenoiser, + Wan22LoopDenoiser( + guider_input_fields={ + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + } + ), + WanLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `WanImage2VideoLoopBeforeDenoiser`\n" + " - `WanLoopDenoiser`\n" + " - `WanLoopAfterDenoiser`\n" + "This block supports image-to-video tasks for Wan2.2." + ) diff --git a/diffusers/modular_pipelines/wan/encoders.py b/diffusers/modular_pipelines/wan/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..3e675a66e4f2b7acbdd09fcbf9a8fd61483c581e --- /dev/null +++ b/diffusers/modular_pipelines/wan/encoders.py @@ -0,0 +1,758 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html + +import numpy as np +import PIL +import regex as re +import torch +from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLWan +from ...utils import is_ftfy_available, is_torchvision_available, logging +from ...video_processor import VideoProcessor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import WanModularPipeline + + +if is_ftfy_available(): + import ftfy + +if is_torchvision_available(): + from torchvision import transforms + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +def get_t5_prompt_embeds( + text_encoder: UMT5EncoderModel, + tokenizer: AutoTokenizer, + prompt: str | list[str], + max_sequence_length: int, + device: torch.device, +): + dtype = text_encoder.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + return prompt_embeds + + +def encode_image( + image: PipelineImageInput, + image_processor: CLIPImageProcessor, + image_encoder: CLIPVisionModel, + device: torch.device | None = None, +): + image = image_processor(images=image, return_tensors="pt").to(device) + image_embeds = image_encoder(**image, output_hidden_states=True) + return image_embeds.hidden_states[-2] + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def encode_vae_image( + video_tensor: torch.Tensor, + vae: AutoencoderKLWan, + generator: torch.Generator, + device: torch.device, + dtype: torch.dtype, + latent_channels: int = 16, +): + if not isinstance(video_tensor, torch.Tensor): + raise ValueError(f"Expected video_tensor to be a tensor, got {type(video_tensor)}.") + + if isinstance(generator, list) and len(generator) != video_tensor.shape[0]: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but it is not same as number of images {video_tensor.shape[0]}." + ) + + video_tensor = video_tensor.to(device=device, dtype=dtype) + + if isinstance(generator, list): + video_latents = [ + retrieve_latents(vae.encode(video_tensor[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(video_tensor.shape[0]) + ] + video_latents = torch.cat(video_latents, dim=0) + else: + video_latents = retrieve_latents(vae.encode(video_tensor), sample_mode="argmax") + + latents_mean = ( + torch.tensor(vae.config.latents_mean) + .view(1, latent_channels, 1, 1, 1) + .to(video_latents.device, video_latents.dtype) + ) + latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, latent_channels, 1, 1, 1).to( + video_latents.device, video_latents.dtype + ) + video_latents = (video_latents - latents_mean) * latents_std + + return video_latents + + +class WanTextEncoderStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "Text Encoder step that generate text_embeddings to guide the video generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", UMT5EncoderModel), + ComponentSpec("tokenizer", AutoTokenizer), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 5.0}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("prompt"), + InputParam("negative_prompt"), + InputParam("max_sequence_length", default=512), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="text embeddings used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="negative text embeddings used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + if block_state.prompt is not None and ( + not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list) + ): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") + + @staticmethod + def encode_prompt( + components, + prompt: str, + device: torch.device | None = None, + prepare_unconditional_embeds: bool = True, + negative_prompt: str | None = None, + max_sequence_length: int = 512, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + prepare_unconditional_embeds (`bool`): + whether to use prepare unconditional embeddings or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + max_sequence_length (`int`, defaults to `512`): + The maximum number of text tokens to be used for the generation process. + """ + device = device or components._execution_device + if not isinstance(prompt, list): + prompt = [prompt] + batch_size = len(prompt) + + prompt_embeds = get_t5_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if prepare_unconditional_embeds: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = get_t5_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=negative_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + return prompt_embeds, negative_prompt_embeds + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + # Get inputs and intermediates + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.device = components._execution_device + + # Encode input prompt + ( + block_state.prompt_embeds, + block_state.negative_prompt_embeds, + ) = self.encode_prompt( + components=components, + prompt=block_state.prompt, + device=block_state.device, + prepare_unconditional_embeds=components.requires_unconditional_embeds, + negative_prompt=block_state.negative_prompt, + max_sequence_length=block_state.max_sequence_length, + ) + + # Add outputs + self.set_block_state(state, block_state) + return components, state + + +class WanImageResizeStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "Image Resize step that resize the image to the target area (height * width) while maintaining the aspect ratio." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("image", type_hint=PIL.Image.Image, required=True), + InputParam("height", type_hint=int, default=480), + InputParam("width", type_hint=int, default=832), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("resized_image", type_hint=PIL.Image.Image), + ] + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + max_area = block_state.height * block_state.width + + image = block_state.image + aspect_ratio = image.height / image.width + mod_value = components.vae_scale_factor_spatial * components.patch_size_spatial + block_state.height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + block_state.width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + block_state.resized_image = image.resize((block_state.width, block_state.height)) + + self.set_block_state(state, block_state) + return components, state + + +class WanImageCropResizeStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "Image Resize step that resize the last_image to the same size of first frame image with center crop." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "resized_image", type_hint=PIL.Image.Image, required=True, description="The resized first frame image" + ), + InputParam("last_image", type_hint=PIL.Image.Image, required=True, description="The last frameimage"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("resized_last_image", type_hint=PIL.Image.Image), + ] + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + height = block_state.resized_image.height + width = block_state.resized_image.width + image = block_state.last_image + + # Calculate resize ratio to match first frame dimensions + resize_ratio = max(width / image.width, height / image.height) + + # Resize the image + width = round(image.width * resize_ratio) + height = round(image.height * resize_ratio) + size = [width, height] + resized_image = transforms.functional.center_crop(image, size) + block_state.resized_last_image = resized_image + + self.set_block_state(state, block_state) + return components, state + + +class WanImageEncoderStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "Image Encoder step that generate image_embeds based on first frame image to guide the video generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("image_processor", CLIPImageProcessor), + ComponentSpec("image_encoder", CLIPVisionModel), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("resized_image", type_hint=PIL.Image.Image, required=True), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("image_embeds", type_hint=torch.Tensor, description="The image embeddings"), + ] + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + + image = block_state.resized_image + + image_embeds = encode_image( + image_processor=components.image_processor, + image_encoder=components.image_encoder, + image=image, + device=device, + ) + block_state.image_embeds = image_embeds + self.set_block_state(state, block_state) + return components, state + + +class WanFirstLastFrameImageEncoderStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "Image Encoder step that generate image_embeds based on first and last frame images to guide the video generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("image_processor", CLIPImageProcessor), + ComponentSpec("image_encoder", CLIPVisionModel), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("resized_image", type_hint=PIL.Image.Image, required=True), + InputParam("resized_last_image", type_hint=PIL.Image.Image, required=True), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("image_embeds", type_hint=torch.Tensor, description="The image embeddings"), + ] + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + + first_frame_image = block_state.resized_image + last_frame_image = block_state.resized_last_image + + image_embeds = encode_image( + image_processor=components.image_processor, + image_encoder=components.image_encoder, + image=[first_frame_image, last_frame_image], + device=device, + ) + block_state.image_embeds = image_embeds + self.set_block_state(state, block_state) + return components, state + + +class WanVaeEncoderStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "Vae Image Encoder step that generate condition_latents based on first frame image to guide the video generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLWan), + ComponentSpec( + "video_processor", + VideoProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("resized_image", type_hint=PIL.Image.Image, required=True), + InputParam("height"), + InputParam("width"), + InputParam("num_frames", type_hint=int, default=81), + InputParam("generator"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "first_frame_latents", + type_hint=torch.Tensor, + description="video latent representation with the first frame image condition", + ), + ] + + @staticmethod + def check_inputs(components, block_state): + if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or ( + block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}." + ) + if block_state.num_frames is not None and ( + block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0 + ): + raise ValueError( + f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}." + ) + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + image = block_state.resized_image + + device = components._execution_device + dtype = torch.float32 + vae_dtype = components.vae.dtype + + height = block_state.height or components.default_height + width = block_state.width or components.default_width + num_frames = block_state.num_frames or components.default_num_frames + + image_tensor = components.video_processor.preprocess(image, height=height, width=width).to( + device=device, dtype=dtype + ) + + if image_tensor.dim() == 4: + image_tensor = image_tensor.unsqueeze(2) + + video_tensor = torch.cat( + [ + image_tensor, + image_tensor.new_zeros(image_tensor.shape[0], image_tensor.shape[1], num_frames - 1, height, width), + ], + dim=2, + ).to(device=device, dtype=dtype) + + block_state.first_frame_latents = encode_vae_image( + video_tensor=video_tensor, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=vae_dtype, + latent_channels=components.num_channels_latents, + ) + + self.set_block_state(state, block_state) + return components, state + + +class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "step that prepares the masked first frame latents and add it to the latent condition" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("first_frame_latents", type_hint=torch.Tensor | None), + InputParam("num_frames", required=True), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("image_condition_latents", type_hint=torch.Tensor | None), + ] + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape + + mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width) + mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0 + + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave( + first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal + ) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view( + batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width + ) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device) + block_state.image_condition_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1) + + self.set_block_state(state, block_state) + return components, state + + +class WanFirstLastFrameVaeEncoderStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "Vae Image Encoder step that generate condition_latents based on first and last frame images to guide the video generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLWan), + ComponentSpec( + "video_processor", + VideoProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("resized_image", type_hint=PIL.Image.Image, required=True), + InputParam("resized_last_image", type_hint=PIL.Image.Image, required=True), + InputParam("height"), + InputParam("width"), + InputParam("num_frames", type_hint=int, default=81), + InputParam("generator"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "first_last_frame_latents", + type_hint=torch.Tensor, + description="video latent representation with the first and last frame images condition", + ), + ] + + @staticmethod + def check_inputs(components, block_state): + if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or ( + block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}." + ) + if block_state.num_frames is not None and ( + block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0 + ): + raise ValueError( + f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}." + ) + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + first_frame_image = block_state.resized_image + last_frame_image = block_state.resized_last_image + + device = components._execution_device + dtype = torch.float32 + vae_dtype = components.vae.dtype + + height = block_state.height or components.default_height + width = block_state.width or components.default_width + num_frames = block_state.num_frames or components.default_num_frames + + first_image_tensor = components.video_processor.preprocess(first_frame_image, height=height, width=width).to( + device=device, dtype=dtype + ) + first_image_tensor = first_image_tensor.unsqueeze(2) + + last_image_tensor = components.video_processor.preprocess(last_frame_image, height=height, width=width).to( + device=device, dtype=dtype + ) + + last_image_tensor = last_image_tensor.unsqueeze(2) + + video_tensor = torch.cat( + [ + first_image_tensor, + first_image_tensor.new_zeros( + first_image_tensor.shape[0], first_image_tensor.shape[1], num_frames - 2, height, width + ), + last_image_tensor, + ], + dim=2, + ).to(device=device, dtype=dtype) + + block_state.first_last_frame_latents = encode_vae_image( + video_tensor=video_tensor, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=vae_dtype, + latent_channels=components.num_channels_latents, + ) + + self.set_block_state(state, block_state) + return components, state + + +class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "step that prepares the masked latents with first and last frames and add it to the latent condition" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("first_last_frame_latents", type_hint=torch.Tensor | None), + InputParam("num_frames", type_hint=int, required=True), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("image_condition_latents", type_hint=torch.Tensor | None), + ] + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape + + mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width) + mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0 + + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave( + first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal + ) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view( + batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width + ) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device) + block_state.image_condition_latents = torch.concat( + [mask_lat_size, block_state.first_last_frame_latents], dim=1 + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/diffusers/modular_pipelines/wan/modular_blocks_wan.py b/diffusers/modular_pipelines/wan/modular_blocks_wan.py new file mode 100644 index 0000000000000000000000000000000000000000..b641c6cd7fcc481f50e91631bb747fad8f2b3e13 --- /dev/null +++ b/diffusers/modular_pipelines/wan/modular_blocks_wan.py @@ -0,0 +1,162 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks +from ..modular_pipeline_utils import OutputParam +from .before_denoise import ( + WanPrepareLatentsStep, + WanSetTimestepsStep, + WanTextInputStep, +) +from .decoders import WanVaeDecoderStep +from .denoise import ( + WanDenoiseStep, +) +from .encoders import ( + WanTextEncoderStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# ==================== +# 1. DENOISE +# ==================== + + +# inputs(text) -> set_timesteps -> prepare_latents -> denoise +# auto_docstring +class WanCoreDenoiseStep(SequentialPipelineBlocks): + """ + denoise block that takes encoded conditions and runs the denoising process. + + Components: + transformer (`WanTransformer3DModel`) scheduler (`UniPCMultistepScheduler`) guider (`ClassifierFreeGuidance`) + + Inputs: + num_videos_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. Can be generated from text_encoder step. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + num_frames (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + attention_kwargs (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "wan" + block_classes = [ + WanTextInputStep, + WanSetTimestepsStep, + WanPrepareLatentsStep, + WanDenoiseStep, + ] + block_names = ["input", "set_timesteps", "prepare_latents", "denoise"] + + @property + def description(self): + return "denoise block that takes encoded conditions and runs the denoising process." + + @property + def outputs(self): + return [OutputParam.template("latents")] + + +# ==================== +# 2. BLOCKS (Wan2.1 text2video) +# ==================== + + +# auto_docstring +class WanBlocks(SequentialPipelineBlocks): + """ + Modular pipeline blocks for Wan2.1. + + Components: + text_encoder (`UMT5EncoderModel`) tokenizer (`AutoTokenizer`) guider (`ClassifierFreeGuidance`) transformer + (`WanTransformer3DModel`) scheduler (`UniPCMultistepScheduler`) vae (`AutoencoderKLWan`) video_processor + (`VideoProcessor`) + + Inputs: + prompt (`None`, *optional*): + TODO: Add description. + negative_prompt (`None`, *optional*): + TODO: Add description. + max_sequence_length (`None`, *optional*, defaults to 512): + TODO: Add description. + num_videos_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + num_frames (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + attention_kwargs (`None`, *optional*): + TODO: Add description. + output_type (`str`, *optional*, defaults to np): + The output type of the decoded videos + + Outputs: + videos (`list`): + The generated videos. + """ + + model_name = "wan" + block_classes = [ + WanTextEncoderStep, + WanCoreDenoiseStep, + WanVaeDecoderStep, + ] + block_names = ["text_encoder", "denoise", "decode"] + + @property + def description(self): + return "Modular pipeline blocks for Wan2.1." + + @property + def outputs(self): + return [OutputParam.template("videos")] diff --git a/diffusers/modular_pipelines/wan/modular_blocks_wan22.py b/diffusers/modular_pipelines/wan/modular_blocks_wan22.py new file mode 100644 index 0000000000000000000000000000000000000000..9f602c24713b73a30ae4950b254b0e94f9cae750 --- /dev/null +++ b/diffusers/modular_pipelines/wan/modular_blocks_wan22.py @@ -0,0 +1,176 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks +from ..modular_pipeline_utils import OutputParam +from .before_denoise import ( + WanPrepareLatentsStep, + WanSetTimestepsStep, + WanTextInputStep, +) +from .decoders import WanVaeDecoderStep +from .denoise import ( + Wan22DenoiseStep, +) +from .encoders import ( + WanTextEncoderStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# ==================== +# 1. DENOISE +# ==================== + +# inputs(text) -> set_timesteps -> prepare_latents -> denoise + + +# auto_docstring +class Wan22CoreDenoiseStep(SequentialPipelineBlocks): + """ + denoise block that takes encoded conditions and runs the denoising process. + + Components: + transformer (`WanTransformer3DModel`) scheduler (`UniPCMultistepScheduler`) guider (`ClassifierFreeGuidance`) + guider_2 (`ClassifierFreeGuidance`) transformer_2 (`WanTransformer3DModel`) + + Configs: + boundary_ratio (default: 0.875): The boundary ratio to divide the denoising loop into high noise and low + noise stages. + + Inputs: + num_videos_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. Can be generated from text_encoder step. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + num_frames (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + attention_kwargs (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "wan" + block_classes = [ + WanTextInputStep, + WanSetTimestepsStep, + WanPrepareLatentsStep, + Wan22DenoiseStep, + ] + block_names = ["input", "set_timesteps", "prepare_latents", "denoise"] + + @property + def description(self): + return "denoise block that takes encoded conditions and runs the denoising process." + + @property + def outputs(self): + return [OutputParam.template("latents")] + + +# ==================== +# 2. BLOCKS (Wan2.2 text2video) +# ==================== + + +# auto_docstring +class Wan22Blocks(SequentialPipelineBlocks): + """ + Modular pipeline for text-to-video using Wan2.2. + + Components: + text_encoder (`UMT5EncoderModel`) tokenizer (`AutoTokenizer`) guider (`ClassifierFreeGuidance`) transformer + (`WanTransformer3DModel`) scheduler (`UniPCMultistepScheduler`) guider_2 (`ClassifierFreeGuidance`) + transformer_2 (`WanTransformer3DModel`) vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`) + + Configs: + boundary_ratio (default: 0.875): The boundary ratio to divide the denoising loop into high noise and low + noise stages. + + Inputs: + prompt (`None`, *optional*): + TODO: Add description. + negative_prompt (`None`, *optional*): + TODO: Add description. + max_sequence_length (`None`, *optional*, defaults to 512): + TODO: Add description. + num_videos_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + num_frames (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + attention_kwargs (`None`, *optional*): + TODO: Add description. + output_type (`str`, *optional*, defaults to np): + The output type of the decoded videos + + Outputs: + videos (`list`): + The generated videos. + """ + + model_name = "wan" + block_classes = [ + WanTextEncoderStep, + Wan22CoreDenoiseStep, + WanVaeDecoderStep, + ] + block_names = [ + "text_encoder", + "denoise", + "decode", + ] + + @property + def description(self): + return "Modular pipeline for text-to-video using Wan2.2." + + @property + def outputs(self): + return [OutputParam.template("videos")] diff --git a/diffusers/modular_pipelines/wan/modular_blocks_wan22_i2v.py b/diffusers/modular_pipelines/wan/modular_blocks_wan22_i2v.py new file mode 100644 index 0000000000000000000000000000000000000000..8e55b7a50f08ae1f6e78f0084dac38abaa676e5a --- /dev/null +++ b/diffusers/modular_pipelines/wan/modular_blocks_wan22_i2v.py @@ -0,0 +1,236 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks +from ..modular_pipeline_utils import OutputParam +from .before_denoise import ( + WanAdditionalInputsStep, + WanPrepareLatentsStep, + WanSetTimestepsStep, + WanTextInputStep, +) +from .decoders import WanVaeDecoderStep +from .denoise import ( + Wan22Image2VideoDenoiseStep, +) +from .encoders import ( + WanImageResizeStep, + WanPrepareFirstFrameLatentsStep, + WanTextEncoderStep, + WanVaeEncoderStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# ==================== +# 1. VAE ENCODER +# ==================== + + +# auto_docstring +class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks): + """ + Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent + representation + + Components: + vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`) + + Inputs: + image (`Image`): + TODO: Add description. + height (`int`, *optional*, defaults to 480): + TODO: Add description. + width (`int`, *optional*, defaults to 832): + TODO: Add description. + num_frames (`int`, *optional*, defaults to 81): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + resized_image (`Image`): + TODO: Add description. + first_frame_latents (`Tensor`): + video latent representation with the first frame image condition + image_condition_latents (`Tensor | NoneType`): + TODO: Add description. + """ + + model_name = "wan-i2v" + block_classes = [WanImageResizeStep, WanVaeEncoderStep, WanPrepareFirstFrameLatentsStep] + block_names = ["image_resize", "vae_encoder", "prepare_first_frame_latents"] + + @property + def description(self): + return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation" + + +# ==================== +# 2. DENOISE +# ==================== + + +# inputs (text + image_condition_latents) -> set_timesteps -> prepare_latents -> denoise (latents) +# auto_docstring +class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks): + """ + denoise block that takes encoded text and image latent conditions and runs the denoising process. + + Components: + transformer (`WanTransformer3DModel`) scheduler (`UniPCMultistepScheduler`) guider (`ClassifierFreeGuidance`) + guider_2 (`ClassifierFreeGuidance`) transformer_2 (`WanTransformer3DModel`) + + Configs: + boundary_ratio (default: 0.875): The boundary ratio to divide the denoising loop into high noise and low + noise stages. + + Inputs: + num_videos_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. Can be generated from text_encoder step. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + num_frames (`None`, *optional*): + TODO: Add description. + image_condition_latents (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + attention_kwargs (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "wan-i2v" + block_classes = [ + WanTextInputStep, + WanAdditionalInputsStep(image_latent_inputs=["image_condition_latents"]), + WanSetTimestepsStep, + WanPrepareLatentsStep, + Wan22Image2VideoDenoiseStep, + ] + block_names = [ + "input", + "additional_inputs", + "set_timesteps", + "prepare_latents", + "denoise", + ] + + @property + def description(self): + return "denoise block that takes encoded text and image latent conditions and runs the denoising process." + + @property + def outputs(self): + return [OutputParam.template("latents")] + + +# ==================== +# 3. BLOCKS (Wan2.2 Image2Video) +# ==================== + + +# auto_docstring +class Wan22Image2VideoBlocks(SequentialPipelineBlocks): + """ + Modular pipeline for image-to-video using Wan2.2. + + Components: + text_encoder (`UMT5EncoderModel`) tokenizer (`AutoTokenizer`) guider (`ClassifierFreeGuidance`) vae + (`AutoencoderKLWan`) video_processor (`VideoProcessor`) transformer (`WanTransformer3DModel`) scheduler + (`UniPCMultistepScheduler`) guider_2 (`ClassifierFreeGuidance`) transformer_2 (`WanTransformer3DModel`) + + Configs: + boundary_ratio (default: 0.875): The boundary ratio to divide the denoising loop into high noise and low + noise stages. + + Inputs: + prompt (`None`, *optional*): + TODO: Add description. + negative_prompt (`None`, *optional*): + TODO: Add description. + max_sequence_length (`None`, *optional*, defaults to 512): + TODO: Add description. + image (`Image`): + TODO: Add description. + height (`int`, *optional*, defaults to 480): + TODO: Add description. + width (`int`, *optional*, defaults to 832): + TODO: Add description. + num_frames (`int`, *optional*, defaults to 81): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_videos_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + attention_kwargs (`None`, *optional*): + TODO: Add description. + output_type (`str`, *optional*, defaults to np): + The output type of the decoded videos + + Outputs: + videos (`list`): + The generated videos. + """ + + model_name = "wan-i2v" + block_classes = [ + WanTextEncoderStep, + WanImage2VideoVaeEncoderStep, + Wan22Image2VideoCoreDenoiseStep, + WanVaeDecoderStep, + ] + block_names = [ + "text_encoder", + "vae_encoder", + "denoise", + "decode", + ] + + @property + def description(self): + return "Modular pipeline for image-to-video using Wan2.2." + + @property + def outputs(self): + return [OutputParam.template("videos")] diff --git a/diffusers/modular_pipelines/wan/modular_blocks_wan_i2v.py b/diffusers/modular_pipelines/wan/modular_blocks_wan_i2v.py new file mode 100644 index 0000000000000000000000000000000000000000..c08db62c469a67298946d751326020d45c1ef646 --- /dev/null +++ b/diffusers/modular_pipelines/wan/modular_blocks_wan_i2v.py @@ -0,0 +1,481 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import OutputParam +from .before_denoise import ( + WanAdditionalInputsStep, + WanPrepareLatentsStep, + WanSetTimestepsStep, + WanTextInputStep, +) +from .decoders import WanVaeDecoderStep +from .denoise import ( + WanImage2VideoDenoiseStep, +) +from .encoders import ( + WanFirstLastFrameImageEncoderStep, + WanFirstLastFrameVaeEncoderStep, + WanImageCropResizeStep, + WanImageEncoderStep, + WanImageResizeStep, + WanPrepareFirstFrameLatentsStep, + WanPrepareFirstLastFrameLatentsStep, + WanTextEncoderStep, + WanVaeEncoderStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# ==================== +# 1. IMAGE ENCODER +# ==================== + + +# wan2.1 I2V (first frame only) +# auto_docstring +class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks): + """ + Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings + + Components: + image_processor (`CLIPImageProcessor`) image_encoder (`CLIPVisionModel`) + + Inputs: + image (`Image`): + TODO: Add description. + height (`int`, *optional*, defaults to 480): + TODO: Add description. + width (`int`, *optional*, defaults to 832): + TODO: Add description. + + Outputs: + resized_image (`Image`): + TODO: Add description. + image_embeds (`Tensor`): + The image embeddings + """ + + model_name = "wan-i2v" + block_classes = [WanImageResizeStep, WanImageEncoderStep] + block_names = ["image_resize", "image_encoder"] + + @property + def description(self): + return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings" + + +# wan2.1 FLF2V (first and last frame) +# auto_docstring +class WanFLF2VImageEncoderStep(SequentialPipelineBlocks): + """ + FLF2V Image Encoder step that resize and encode and encode the first and last frame images to generate the image + embeddings + + Components: + image_processor (`CLIPImageProcessor`) image_encoder (`CLIPVisionModel`) + + Inputs: + image (`Image`): + TODO: Add description. + height (`int`, *optional*, defaults to 480): + TODO: Add description. + width (`int`, *optional*, defaults to 832): + TODO: Add description. + last_image (`Image`): + The last frameimage + + Outputs: + resized_image (`Image`): + TODO: Add description. + resized_last_image (`Image`): + TODO: Add description. + image_embeds (`Tensor`): + The image embeddings + """ + + model_name = "wan-i2v" + block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep] + block_names = ["image_resize", "last_image_resize", "image_encoder"] + + @property + def description(self): + return "FLF2V Image Encoder step that resize and encode and encode the first and last frame images to generate the image embeddings" + + +# wan2.1 Auto Image Encoder +# auto_docstring +class WanAutoImageEncoderStep(AutoPipelineBlocks): + """ + Image Encoder step that encode the image to generate the image embeddingsThis is an auto pipeline block that works + for image2video tasks. - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided. - + `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided. - if `last_image` or `image` is + not provided, step will be skipped. + + Components: + image_processor (`CLIPImageProcessor`) image_encoder (`CLIPVisionModel`) + + Inputs: + image (`Image`, *optional*): + TODO: Add description. + height (`int`, *optional*, defaults to 480): + TODO: Add description. + width (`int`, *optional*, defaults to 832): + TODO: Add description. + last_image (`Image`, *optional*): + The last frameimage + + Outputs: + resized_image (`Image`): + TODO: Add description. + resized_last_image (`Image`): + TODO: Add description. + image_embeds (`Tensor`): + The image embeddings + """ + + block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep] + block_names = ["flf2v_image_encoder", "image2video_image_encoder"] + block_trigger_inputs = ["last_image", "image"] + model_name = "wan-i2v" + + @property + def description(self): + return ( + "Image Encoder step that encode the image to generate the image embeddings" + + "This is an auto pipeline block that works for image2video tasks." + + " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided." + + " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided." + + " - if `last_image` or `image` is not provided, step will be skipped." + ) + + +# ==================== +# 2. VAE ENCODER +# ==================== + + +# wan2.1 I2V (first frame only) +# auto_docstring +class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks): + """ + Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent + representation + + Components: + vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`) + + Inputs: + image (`Image`): + TODO: Add description. + height (`int`, *optional*, defaults to 480): + TODO: Add description. + width (`int`, *optional*, defaults to 832): + TODO: Add description. + num_frames (`int`, *optional*, defaults to 81): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + resized_image (`Image`): + TODO: Add description. + first_frame_latents (`Tensor`): + video latent representation with the first frame image condition + image_condition_latents (`Tensor | NoneType`): + TODO: Add description. + """ + + model_name = "wan-i2v" + block_classes = [WanImageResizeStep, WanVaeEncoderStep, WanPrepareFirstFrameLatentsStep] + block_names = ["image_resize", "vae_encoder", "prepare_first_frame_latents"] + + @property + def description(self): + return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation" + + +# wan2.1 FLF2V (first and last frame) +# auto_docstring +class WanFLF2VVaeEncoderStep(SequentialPipelineBlocks): + """ + FLF2V Vae Image Encoder step that resize and encode and encode the first and last frame images to generate the + latent conditions + + Components: + vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`) + + Inputs: + image (`Image`): + TODO: Add description. + height (`int`, *optional*, defaults to 480): + TODO: Add description. + width (`int`, *optional*, defaults to 832): + TODO: Add description. + last_image (`Image`): + The last frameimage + num_frames (`int`, *optional*, defaults to 81): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + resized_image (`Image`): + TODO: Add description. + resized_last_image (`Image`): + TODO: Add description. + first_last_frame_latents (`Tensor`): + video latent representation with the first and last frame images condition + image_condition_latents (`Tensor | NoneType`): + TODO: Add description. + """ + + model_name = "wan-i2v" + block_classes = [ + WanImageResizeStep, + WanImageCropResizeStep, + WanFirstLastFrameVaeEncoderStep, + WanPrepareFirstLastFrameLatentsStep, + ] + block_names = ["image_resize", "last_image_resize", "vae_encoder", "prepare_first_last_frame_latents"] + + @property + def description(self): + return "FLF2V Vae Image Encoder step that resize and encode and encode the first and last frame images to generate the latent conditions" + + +# wan2.1 Auto Vae Encoder +# auto_docstring +class WanAutoVaeEncoderStep(AutoPipelineBlocks): + """ + Vae Image Encoder step that encode the image to generate the image latentsThis is an auto pipeline block that works + for image2video tasks. - `WanFLF2VVaeEncoderStep` (flf2v) is used when `last_image` is provided. - + `WanImage2VideoVaeEncoderStep` (image2video) is used when `image` is provided. - if `last_image` or `image` is not + provided, step will be skipped. + + Components: + vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`) + + Inputs: + image (`Image`, *optional*): + TODO: Add description. + height (`int`, *optional*, defaults to 480): + TODO: Add description. + width (`int`, *optional*, defaults to 832): + TODO: Add description. + last_image (`Image`, *optional*): + The last frameimage + num_frames (`int`, *optional*, defaults to 81): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + resized_image (`Image`): + TODO: Add description. + resized_last_image (`Image`): + TODO: Add description. + first_last_frame_latents (`Tensor`): + video latent representation with the first and last frame images condition + image_condition_latents (`Tensor | NoneType`): + TODO: Add description. + first_frame_latents (`Tensor`): + video latent representation with the first frame image condition + """ + + model_name = "wan-i2v" + block_classes = [WanFLF2VVaeEncoderStep, WanImage2VideoVaeEncoderStep] + block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"] + block_trigger_inputs = ["last_image", "image"] + + @property + def description(self): + return ( + "Vae Image Encoder step that encode the image to generate the image latents" + + "This is an auto pipeline block that works for image2video tasks." + + " - `WanFLF2VVaeEncoderStep` (flf2v) is used when `last_image` is provided." + + " - `WanImage2VideoVaeEncoderStep` (image2video) is used when `image` is provided." + + " - if `last_image` or `image` is not provided, step will be skipped." + ) + + +# ==================== +# 3. DENOISE (inputs -> set_timesteps -> prepare_latents -> denoise) +# ==================== + + +# wan2.1 I2V core denoise (support both I2V and FLF2V) +# inputs (text + image_condition_latents) -> set_timesteps -> prepare_latents -> denoise (latents) +# auto_docstring +class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks): + """ + denoise block that takes encoded text and image latent conditions and runs the denoising process. + + Components: + transformer (`WanTransformer3DModel`) scheduler (`UniPCMultistepScheduler`) guider (`ClassifierFreeGuidance`) + + Inputs: + num_videos_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. Can be generated from text_encoder step. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + num_frames (`None`, *optional*): + TODO: Add description. + image_condition_latents (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + attention_kwargs (`None`, *optional*): + TODO: Add description. + image_embeds (`Tensor`): + TODO: Add description. + + Outputs: + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be batch_size * num_videos_per_prompt + dtype (`dtype`): + Data type of model tensor inputs (determined by `transformer.dtype`) + latents (`Tensor`): + The initial latents to use for the denoising process + """ + + model_name = "wan-i2v" + block_classes = [ + WanTextInputStep, + WanAdditionalInputsStep(image_latent_inputs=["image_condition_latents"]), + WanSetTimestepsStep, + WanPrepareLatentsStep, + WanImage2VideoDenoiseStep, + ] + block_names = [ + "input", + "additional_inputs", + "set_timesteps", + "prepare_latents", + "denoise", + ] + + @property + def description(self): + return "denoise block that takes encoded text and image latent conditions and runs the denoising process." + + +# ==================== +# 4. BLOCKS (Wan2.1 Image2Video) +# ==================== + + +# wan2.1 Image2Video Auto Blocks +# auto_docstring +class WanImage2VideoAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for image-to-video using Wan. + + Supported workflows: + - `image2video`: requires `image`, `prompt` + - `flf2v`: requires `last_image`, `image`, `prompt` + + Components: + text_encoder (`UMT5EncoderModel`) tokenizer (`AutoTokenizer`) guider (`ClassifierFreeGuidance`) + image_processor (`CLIPImageProcessor`) image_encoder (`CLIPVisionModel`) vae (`AutoencoderKLWan`) + video_processor (`VideoProcessor`) transformer (`WanTransformer3DModel`) scheduler + (`UniPCMultistepScheduler`) + + Inputs: + prompt (`None`, *optional*): + TODO: Add description. + negative_prompt (`None`, *optional*): + TODO: Add description. + max_sequence_length (`None`, *optional*, defaults to 512): + TODO: Add description. + image (`Image`, *optional*): + TODO: Add description. + height (`int`, *optional*, defaults to 480): + TODO: Add description. + width (`int`, *optional*, defaults to 832): + TODO: Add description. + last_image (`Image`, *optional*): + The last frameimage + num_frames (`int`, *optional*, defaults to 81): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_videos_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + image_condition_latents (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + attention_kwargs (`None`, *optional*): + TODO: Add description. + image_embeds (`Tensor`): + TODO: Add description. + output_type (`str`, *optional*, defaults to np): + The output type of the decoded videos + + Outputs: + videos (`list`): + The generated videos. + """ + + model_name = "wan-i2v" + block_classes = [ + WanTextEncoderStep, + WanAutoImageEncoderStep, + WanAutoVaeEncoderStep, + WanImage2VideoCoreDenoiseStep, + WanVaeDecoderStep, + ] + block_names = [ + "text_encoder", + "image_encoder", + "vae_encoder", + "denoise", + "decode", + ] + + _workflow_map = { + "image2video": {"image": True, "prompt": True}, + "flf2v": {"last_image": True, "image": True, "prompt": True}, + } + + @property + def description(self): + return "Auto Modular pipeline for image-to-video using Wan." + + @property + def outputs(self): + return [OutputParam.template("videos")] diff --git a/diffusers/modular_pipelines/wan/modular_pipeline.py b/diffusers/modular_pipelines/wan/modular_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..0e52026a51bf9462752c566fb99df0b8de0e3814 --- /dev/null +++ b/diffusers/modular_pipelines/wan/modular_pipeline.py @@ -0,0 +1,141 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...loaders import WanLoraLoaderMixin +from ...pipelines.pipeline_utils import StableDiffusionMixin +from ...utils import logging +from ..modular_pipeline import ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class WanModularPipeline( + ModularPipeline, + StableDiffusionMixin, + WanLoraLoaderMixin, +): + """ + A ModularPipeline for Wan2.1 text2video. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "WanBlocks" + + @property + def default_height(self): + return self.default_sample_height * self.vae_scale_factor_spatial + + @property + def default_width(self): + return self.default_sample_width * self.vae_scale_factor_spatial + + @property + def default_num_frames(self): + return (self.default_sample_num_frames - 1) * self.vae_scale_factor_temporal + 1 + + @property + def default_sample_height(self): + return 60 + + @property + def default_sample_width(self): + return 104 + + @property + def default_sample_num_frames(self): + return 21 + + @property + def patch_size_spatial(self): + patch_size_spatial = 2 + if hasattr(self, "transformer") and self.transformer is not None: + patch_size_spatial = self.transformer.config.patch_size[1] + return patch_size_spatial + + @property + def vae_scale_factor_spatial(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** len(self.vae.temperal_downsample) + return vae_scale_factor + + @property + def vae_scale_factor_temporal(self): + vae_scale_factor = 4 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** sum(self.vae.temperal_downsample) + return vae_scale_factor + + @property + def num_channels_transformer(self): + num_channels_transformer = 16 + if hasattr(self, "transformer") and self.transformer is not None: + num_channels_transformer = self.transformer.config.in_channels + return num_channels_transformer + + @property + def num_channels_latents(self): + num_channels_latents = 16 + if hasattr(self, "vae") and self.vae is not None: + num_channels_latents = self.vae.config.z_dim + return num_channels_latents + + @property + def requires_unconditional_embeds(self): + requires_unconditional_embeds = False + + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 + + return requires_unconditional_embeds + + @property + def num_train_timesteps(self): + num_train_timesteps = 1000 + if hasattr(self, "scheduler") and self.scheduler is not None: + num_train_timesteps = self.scheduler.config.num_train_timesteps + return num_train_timesteps + + +class WanImage2VideoModularPipeline(WanModularPipeline): + """ + A ModularPipeline for Wan2.1 image2video (both I2V and FLF2V). + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "WanImage2VideoAutoBlocks" + + +class Wan22ModularPipeline(WanModularPipeline): + """ + A ModularPipeline for Wan2.2 text2video. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "Wan22Blocks" + + +class Wan22Image2VideoModularPipeline(Wan22ModularPipeline): + """ + A ModularPipeline for Wan2.2 image2video. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "Wan22Image2VideoBlocks" diff --git a/diffusers/modular_pipelines/z_image/__init__.py b/diffusers/modular_pipelines/z_image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5c04008d33052ede093286a8a800b01b960d0e27 --- /dev/null +++ b/diffusers/modular_pipelines/z_image/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modular_blocks_z_image"] = ["ZImageAutoBlocks"] + _import_structure["modular_pipeline"] = ["ZImageModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_blocks_z_image import ZImageAutoBlocks + from .modular_pipeline import ZImageModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/modular_pipelines/z_image/before_denoise.py b/diffusers/modular_pipelines/z_image/before_denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..8558f2c67f655c907979eba60520a99f2b139a0f --- /dev/null +++ b/diffusers/modular_pipelines/z_image/before_denoise.py @@ -0,0 +1,620 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import torch + +from ...models import ZImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import ZImageModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that +# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by +# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the +# configuration of guider is. + + +def repeat_tensor_to_batch_size( + input_name: str, + input_tensor: torch.Tensor, + batch_size: int, + num_images_per_prompt: int = 1, +) -> torch.Tensor: + """Repeat tensor elements to match the final batch size. + + This function expands a tensor's batch dimension to match the final batch size (batch_size * num_images_per_prompt) + by repeating each element along dimension 0. + + The input tensor must have batch size 1 or batch_size. The function will: + - If batch size is 1: repeat each element (batch_size * num_images_per_prompt) times + - If batch size equals batch_size: repeat each element num_images_per_prompt times + + Args: + input_name (str): Name of the input tensor (used for error messages) + input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size. + batch_size (int): The base batch size (number of prompts) + num_images_per_prompt (int, optional): Number of images to generate per prompt. Defaults to 1. + + Returns: + torch.Tensor: The repeated tensor with final batch size (batch_size * num_images_per_prompt) + + Raises: + ValueError: If input_tensor is not a torch.Tensor or has invalid batch size + + Examples: + tensor = torch.tensor([[1, 2, 3]]) # shape: [1, 3] repeated = repeat_tensor_to_batch_size("image", tensor, + batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape: + [4, 3] + + tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: [2, 3] repeated = repeat_tensor_to_batch_size("image", + tensor, batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]) + - shape: [4, 3] + """ + # make sure input is a tensor + if not isinstance(input_tensor, torch.Tensor): + raise ValueError(f"`{input_name}` must be a tensor") + + # make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts + if input_tensor.shape[0] == 1: + repeat_by = batch_size * num_images_per_prompt + elif input_tensor.shape[0] == batch_size: + repeat_by = num_images_per_prompt + else: + raise ValueError( + f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}" + ) + + # expand the tensor to match the batch_size * num_images_per_prompt + input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0) + + return input_tensor + + +def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor_spatial: int) -> tuple[int, int]: + """Calculate image dimensions from latent tensor dimensions. + + This function converts latent spatial dimensions to image spatial dimensions by multiplying the latent height/width + by the VAE scale factor. + + Args: + latents (torch.Tensor): The latent tensor. Must have 4 dimensions. + Expected shapes: [batch, channels, height, width] + vae_scale_factor (int): The scale factor used by the VAE to compress image spatial dimension. + By default, it is 16 + Returns: + tuple[int, int]: The calculated image dimensions as (height, width) + """ + latent_height, latent_width = latents.shape[2:] + height = latent_height * vae_scale_factor_spatial // 2 + width = latent_width * vae_scale_factor_spatial // 2 + + return height, width + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImageTextInputStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return ( + "Input processing step that:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n" + "All input tensors are expected to have either batch_size=1 or match the batch_size\n" + "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" + "have a final batch_size of batch_size * num_images_per_prompt." + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("transformer", ZImageTransformer2DModel), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + InputParam( + "prompt_embeds", + required=True, + type_hint=list[torch.Tensor], + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "negative_prompt_embeds", + type_hint=list[torch.Tensor], + description="Pre-generated negative text embeddings. Can be generated from text_encoder step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `transformer.dtype`)", + ), + ] + + def check_inputs(self, components, block_state): + if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None: + if not isinstance(block_state.prompt_embeds, list): + raise ValueError( + f"`prompt_embeds` must be a list when passed directly, but got {type(block_state.prompt_embeds)}." + ) + if not isinstance(block_state.negative_prompt_embeds, list): + raise ValueError( + f"`negative_prompt_embeds` must be a list when passed directly, but got {type(block_state.negative_prompt_embeds)}." + ) + if len(block_state.prompt_embeds) != len(block_state.negative_prompt_embeds): + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same length when passed directly, but" + f" got: `prompt_embeds` {len(block_state.prompt_embeds)} != `negative_prompt_embeds`" + f" {len(block_state.negative_prompt_embeds)}." + ) + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + block_state.batch_size = len(block_state.prompt_embeds) + block_state.dtype = block_state.prompt_embeds[0].dtype + + if block_state.num_images_per_prompt > 1: + prompt_embeds = [pe for pe in block_state.prompt_embeds for _ in range(block_state.num_images_per_prompt)] + block_state.prompt_embeds = prompt_embeds + + if block_state.negative_prompt_embeds is not None: + negative_prompt_embeds = [ + npe for npe in block_state.negative_prompt_embeds for _ in range(block_state.num_images_per_prompt) + ] + block_state.negative_prompt_embeds = negative_prompt_embeds + + self.set_block_state(state, block_state) + + return components, state + + +class ZImageAdditionalInputsStep(ModularPipelineBlocks): + model_name = "z-image" + + def __init__( + self, + image_latent_inputs: list[str] = ["image_latents"], + additional_batch_inputs: list[str] = [], + ): + """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n" + + This step handles multiple common tasks to prepare inputs for the denoising step: + 1. For encoded image latents, use it update height/width if None, and expands batch size + 2. For additional_batch_inputs: Only expands batch dimensions to match final batch size + + This is a dynamic block that allows you to configure which inputs to process. + + Args: + image_latent_inputs (list[str], optional): Names of image latent tensors to process. + In additional to adjust batch size of these inputs, they will be used to determine height/width. Can be + a single string or list of strings. Defaults to ["image_latents"]. + additional_batch_inputs (list[str], optional): + Names of additional conditional input tensors to expand batch size. These tensors will only have their + batch dimensions adjusted to match the final batch size. Can be a single string or list of strings. + Defaults to []. + + Examples: + # Configure to process image_latents (default behavior) ZImageAdditionalInputsStep() + + # Configure to process multiple image latent inputs + ZImageAdditionalInputsStep(image_latent_inputs=["image_latents", "control_image_latents"]) + + # Configure to process image latents and additional batch inputs ZImageAdditionalInputsStep( + image_latent_inputs=["image_latents"], additional_batch_inputs=["image_embeds"] + ) + """ + if not isinstance(image_latent_inputs, list): + image_latent_inputs = [image_latent_inputs] + if not isinstance(additional_batch_inputs, list): + additional_batch_inputs = [additional_batch_inputs] + + self._image_latent_inputs = image_latent_inputs + self._additional_batch_inputs = additional_batch_inputs + super().__init__() + + @property + def description(self) -> str: + # Functionality section + summary_section = ( + "Input processing step that:\n" + " 1. For image latent inputs: Updates height/width if None, and expands batch size\n" + " 2. For additional batch inputs: Expands batch dimensions to match final batch size" + ) + + # Inputs info + inputs_info = "" + if self._image_latent_inputs or self._additional_batch_inputs: + inputs_info = "\n\nConfigured inputs:" + if self._image_latent_inputs: + inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}" + if self._additional_batch_inputs: + inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}" + + # Placement guidance + placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." + + return summary_section + inputs_info + placement_section + + @property + def inputs(self) -> list[InputParam]: + inputs = [ + InputParam(name="num_images_per_prompt", default=1), + InputParam(name="batch_size", required=True), + InputParam(name="height"), + InputParam(name="width"), + ] + + # Add image latent inputs + for image_latent_input_name in self._image_latent_inputs: + inputs.append(InputParam(name=image_latent_input_name)) + + # Add additional batch inputs + for input_name in self._additional_batch_inputs: + inputs.append(InputParam(name=input_name)) + + return inputs + + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Process image latent inputs (height/width calculation, patchify, and batch expansion) + for image_latent_input_name in self._image_latent_inputs: + image_latent_tensor = getattr(block_state, image_latent_input_name) + if image_latent_tensor is None: + continue + + # 1. Calculate num_frames, height/width from latents + height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor_spatial) + block_state.height = block_state.height or height + block_state.width = block_state.width or width + + # Process additional batch inputs (only batch expansion) + for input_name in self._additional_batch_inputs: + input_tensor = getattr(block_state, input_name) + if input_tensor is None: + continue + + # Only expand batch size + input_tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=input_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, input_name, input_tensor) + + self.set_block_state(state, block_state) + return components, state + + +class ZImagePrepareLatentsStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return "Prepare latents step that prepares the latents for the text-to-video generation process" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("latents", type_hint=torch.Tensor | None), + InputParam("num_images_per_prompt", type_hint=int, default=1), + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.", + ), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ) + ] + + def check_inputs(self, components, block_state): + if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or ( + block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}." + ) + + @staticmethod + # Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline.prepare_latents with self->comp + def prepare_latents( + comp, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (comp.vae_scale_factor * 2)) + width = 2 * (int(width) // (comp.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + device = components._execution_device + dtype = torch.float32 + + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + + block_state.latents = self.prepare_latents( + components, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_channels_latents=components.num_channels_latents, + height=block_state.height, + width=block_state.width, + dtype=dtype, + device=device, + generator=block_state.generator, + latents=block_state.latents, + ) + + self.set_block_state(state, block_state) + + return components, state + + +class ZImageSetTimestepsStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference. Need to run after prepare latents step." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("latents", required=True), + InputParam("num_inference_steps", default=9), + InputParam("sigmas"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process" + ), + ] + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + + latent_height, latent_width = block_state.latents.shape[2], block_state.latents.shape[3] + image_seq_len = (latent_height // 2) * (latent_width // 2) # sequence length after patchify + + mu = calculate_shift( + image_seq_len, + base_seq_len=components.scheduler.config.get("base_image_seq_len", 256), + max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096), + base_shift=components.scheduler.config.get("base_shift", 0.5), + max_shift=components.scheduler.config.get("max_shift", 1.15), + ) + components.scheduler.sigma_min = 0.0 + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, + block_state.num_inference_steps, + device, + sigmas=block_state.sigmas, + mu=mu, + ) + + self.set_block_state(state, block_state) + return components, state + + +class ZImageSetTimestepsWithStrengthStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference with strength. Need to run after set timesteps step." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("timesteps", required=True), + InputParam("num_inference_steps", required=True), + InputParam("strength", default=0.6), + ] + + def check_inputs(self, components, block_state): + if block_state.strength < 0.0 or block_state.strength > 1.0: + raise ValueError(f"Strength must be between 0.0 and 1.0, but got {block_state.strength}") + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + init_timestep = min(block_state.num_inference_steps * block_state.strength, block_state.num_inference_steps) + + t_start = int(max(block_state.num_inference_steps - init_timestep, 0)) + timesteps = components.scheduler.timesteps[t_start * components.scheduler.order :] + if hasattr(components.scheduler, "set_begin_index"): + components.scheduler.set_begin_index(t_start * components.scheduler.order) + + block_state.timesteps = timesteps + block_state.num_inference_steps = block_state.num_inference_steps - t_start + + self.set_block_state(state, block_state) + return components, state + + +class ZImagePrepareLatentswithImageStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return "step that prepares the latents with image condition, need to run after set timesteps and prepare latents step." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("latents", required=True), + InputParam("image_latents", required=True), + InputParam("timesteps", required=True), + ] + + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0]) + block_state.latents = components.scheduler.scale_noise( + block_state.image_latents, latent_timestep, block_state.latents + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/diffusers/modular_pipelines/z_image/decoders.py b/diffusers/modular_pipelines/z_image/decoders.py new file mode 100644 index 0000000000000000000000000000000000000000..353253102376c7a0973ae1c457a314ad0cbe8ff5 --- /dev/null +++ b/diffusers/modular_pipelines/z_image/decoders.py @@ -0,0 +1,91 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import numpy as np +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class ZImageVaeDecoderStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8 * 2}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images" + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam( + "latents", + required=True, + ), + InputParam( + name="output_type", + default="pil", + type_hint=str, + description="The type of the output images, can be 'pil', 'np', 'pt'", + ), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "images", + type_hint=list[PIL.Image.Image, list[torch.Tensor], list[np.ndarray]], + description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array", + ) + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + vae_dtype = components.vae.dtype + + latents = block_state.latents.to(vae_dtype) + latents = latents / components.vae.config.scaling_factor + components.vae.config.shift_factor + + block_state.images = components.vae.decode(latents, return_dict=False)[0] + block_state.images = components.image_processor.postprocess( + block_state.images, output_type=block_state.output_type + ) + + self.set_block_state(state, block_state) + + return components, state diff --git a/diffusers/modular_pipelines/z_image/denoise.py b/diffusers/modular_pipelines/z_image/denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..863df312389a56794663a36082dce143e60bf80c --- /dev/null +++ b/diffusers/modular_pipelines/z_image/denoise.py @@ -0,0 +1,314 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import ZImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, InputParam +from .modular_pipeline import ZImageModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class ZImageLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepares the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `ZImageDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of the model inputs. Can be generated in input step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + latents = block_state.latents.unsqueeze(2).to( + block_state.dtype + ) # [batch_size, num_channels, 1, height, width] + block_state.latent_model_input = list(latents.unbind(dim=0)) # list of [num_channels, 1, height, width] + + timestep = t.expand(latents.shape[0]).to(block_state.dtype) + timestep = (1000 - timestep) / 1000 + block_state.timestep = timestep + return components, block_state + + +class ZImageLoopDenoiser(ModularPipelineBlocks): + model_name = "z-image" + + def __init__( + self, + guider_input_fields: dict[str, Any] = {"cap_feats": ("prompt_embeds", "negative_prompt_embeds")}, + ): + """Initialize a denoiser block that calls the denoiser model. This block is used in Z-Image. + + Args: + guider_input_fields: A dictionary that maps each argument expected by the denoiser model + (for example, "encoder_hidden_states") to data stored on 'block_state'. The value can be either: + + - A tuple of strings. For instance, {"encoder_hidden_states": ("prompt_embeds", + "negative_prompt_embeds")} tells the guider to read `block_state.prompt_embeds` and + `block_state.negative_prompt_embeds` and pass them as the conditional and unconditional batches of + 'encoder_hidden_states'. + - A string. For example, {"encoder_hidden_image": "image_embeds"} makes the guider forward + `block_state.image_embeds` for both conditional and unconditional batches. + """ + if not isinstance(guider_input_fields, dict): + raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}") + self._guider_input_fields = guider_input_fields + super().__init__() + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 5.0, "enabled": False}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", ZImageTransformer2DModel), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents with guidance. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `ZImageDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + inputs = [ + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + kwargs_type="denoiser_input_fields", + description="The conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.", + ), + ] + guider_input_names = [] + uncond_guider_input_names = [] + for value in self._guider_input_fields.values(): + if isinstance(value, tuple): + guider_input_names.append(value[0]) + uncond_guider_input_names.append(value[1]) + else: + guider_input_names.append(value) + + for name in guider_input_names: + inputs.append(InputParam(name=name, required=True)) + for name in uncond_guider_input_names: + inputs.append(InputParam(name=name)) + return inputs + + @torch.no_grad() + def __call__( + self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). + guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = guider_state_batch.as_dict() + + def _convert_dtype(v, dtype): + if isinstance(v, torch.Tensor): + return v.to(dtype) + elif isinstance(v, list): + return [_convert_dtype(t, dtype) for t in v] + return v + + cond_kwargs = { + k: _convert_dtype(v, block_state.dtype) + for k, v in cond_kwargs.items() + if k in self._guider_input_fields.keys() + } + + # Predict the noise residual + # store the noise_pred in guider_state_batch so that we can apply guidance across all batches + model_out_list = components.transformer( + x=block_state.latent_model_input, + t=block_state.timestep, + return_dict=False, + **cond_kwargs, + )[0] + noise_pred = torch.stack(model_out_list, dim=0).squeeze(2) + guider_state_batch.noise_pred = -noise_pred + components.guider.cleanup_models(components.transformer) + + # Perform guidance + block_state.noise_pred = components.guider(guider_state)[0] + + return components, block_state + + +class ZImageLoopAfterDenoiser(ModularPipelineBlocks): + model_name = "z-image" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that update the latents. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `ZImageDenoiseLoopWrapper`)" + ) + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + # Perform scheduler step using the predicted output + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred.float(), + t, + block_state.latents.float(), + return_dict=False, + )[0] + + if block_state.latents.dtype != latents_dtype: + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class ZImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoise the latents over `timesteps`. " + "The specific steps with each iteration can be customized with `sub_blocks` attributes" + ) + + @property + def loop_expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def loop_inputs(self) -> list[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + ) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + self.set_block_state(state, block_state) + + return components, state + + +class ZImageDenoiseStep(ZImageDenoiseLoopWrapper): + block_classes = [ + ZImageLoopBeforeDenoiser, + ZImageLoopDenoiser( + guider_input_fields={ + "cap_feats": ("prompt_embeds", "negative_prompt_embeds"), + } + ), + ZImageLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `ZImageDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `ZImageLoopBeforeDenoiser`\n" + " - `ZImageLoopDenoiser`\n" + " - `ZImageLoopAfterDenoiser`\n" + "This block supports text-to-image and image-to-image tasks for Z-Image." + ) diff --git a/diffusers/modular_pipelines/z_image/encoders.py b/diffusers/modular_pipelines/z_image/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..06deb823689302c99cad58643350c3f184deb28f --- /dev/null +++ b/diffusers/modular_pipelines/z_image/encoders.py @@ -0,0 +1,343 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import PIL +import torch +from transformers import Qwen2Tokenizer, Qwen3Model + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL +from ...utils import is_ftfy_available, logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import ZImageModularPipeline + + +if is_ftfy_available(): + pass + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_qwen_prompt_embeds( + text_encoder: Qwen3Model, + tokenizer: Qwen2Tokenizer, + prompt: str | list[str], + device: torch.device, + max_sequence_length: int = 512, +) -> list[torch.Tensor]: + prompt = [prompt] if isinstance(prompt, str) else prompt + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + prompt_embeds_list = [] + + for i in range(len(prompt_embeds)): + prompt_embeds_list.append(prompt_embeds[i][prompt_masks[i]]) + + return prompt_embeds_list + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def encode_vae_image( + image_tensor: torch.Tensor, + vae: AutoencoderKL, + generator: torch.Generator, + device: torch.device, + dtype: torch.dtype, + latent_channels: int = 16, +): + if not isinstance(image_tensor, torch.Tensor): + raise ValueError(f"Expected image_tensor to be a tensor, got {type(image_tensor)}.") + + if isinstance(generator, list) and len(generator) != image_tensor.shape[0]: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but it is not same as number of images {image_tensor.shape[0]}." + ) + + image_tensor = image_tensor.to(device=device, dtype=dtype) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(vae.encode(image_tensor[i : i + 1]), generator=generator[i]) + for i in range(image_tensor.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(vae.encode(image_tensor), generator=generator) + + image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor + + return image_latents + + +class ZImageTextEncoderStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return "Text Encoder step that generate text_embeddings to guide the video generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen3Model), + ComponentSpec("tokenizer", Qwen2Tokenizer), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 5.0, "enabled": False}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("prompt"), + InputParam("negative_prompt"), + InputParam("max_sequence_length", default=512), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + type_hint=list[torch.Tensor], + kwargs_type="denoiser_input_fields", + description="text embeddings used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=list[torch.Tensor], + kwargs_type="denoiser_input_fields", + description="negative text embeddings used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + if block_state.prompt is not None and ( + not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list) + ): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") + + @staticmethod + def encode_prompt( + components, + prompt: str, + device: torch.device | None = None, + prepare_unconditional_embeds: bool = True, + negative_prompt: str | None = None, + max_sequence_length: int = 512, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + prepare_unconditional_embeds (`bool`): + whether to use prepare unconditional embeddings or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + max_sequence_length (`int`, defaults to `512`): + The maximum number of text tokens to be used for the generation process. + """ + device = device or components._execution_device + if not isinstance(prompt, list): + prompt = [prompt] + batch_size = len(prompt) + + prompt_embeds = get_qwen_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_prompt_embeds = None + if prepare_unconditional_embeds: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = get_qwen_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=negative_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + return prompt_embeds, negative_prompt_embeds + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + # Get inputs and intermediates + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.device = components._execution_device + + # Encode input prompt + ( + block_state.prompt_embeds, + block_state.negative_prompt_embeds, + ) = self.encode_prompt( + components=components, + prompt=block_state.prompt, + device=block_state.device, + prepare_unconditional_embeds=components.requires_unconditional_embeds, + negative_prompt=block_state.negative_prompt, + max_sequence_length=block_state.max_sequence_length, + ) + + # Add outputs + self.set_block_state(state, block_state) + return components, state + + +class ZImageVaeImageEncoderStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return "Vae Image Encoder step that generate condition_latents based on image to guide the image generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8 * 2}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("image", type_hint=PIL.Image.Image, required=True), + InputParam("height"), + InputParam("width"), + InputParam("generator"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "image_latents", + type_hint=torch.Tensor, + description="video latent representation with the first frame image condition", + ), + ] + + @staticmethod + def check_inputs(components, block_state): + if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or ( + block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}." + ) + + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + image = block_state.image + + device = components._execution_device + dtype = torch.float32 + vae_dtype = components.vae.dtype + + image_tensor = components.image_processor.preprocess( + image, height=block_state.height, width=block_state.width + ).to(device=device, dtype=dtype) + + block_state.image_latents = encode_vae_image( + image_tensor=image_tensor, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=vae_dtype, + latent_channels=components.num_channels_latents, + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/diffusers/modular_pipelines/z_image/modular_blocks_z_image.py b/diffusers/modular_pipelines/z_image/modular_blocks_z_image.py new file mode 100644 index 0000000000000000000000000000000000000000..23e20d55fb1e8e3ed23fcb141ce525b3c6e8270b --- /dev/null +++ b/diffusers/modular_pipelines/z_image/modular_blocks_z_image.py @@ -0,0 +1,334 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import OutputParam +from .before_denoise import ( + ZImageAdditionalInputsStep, + ZImagePrepareLatentsStep, + ZImagePrepareLatentswithImageStep, + ZImageSetTimestepsStep, + ZImageSetTimestepsWithStrengthStep, + ZImageTextInputStep, +) +from .decoders import ZImageVaeDecoderStep +from .denoise import ( + ZImageDenoiseStep, +) +from .encoders import ( + ZImageTextEncoderStep, + ZImageVaeImageEncoderStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# ==================== +# 1. DENOISE +# ==================== + + +# text2image: inputs(text) -> set_timesteps -> prepare_latents -> denoise +# auto_docstring +class ZImageCoreDenoiseStep(SequentialPipelineBlocks): + """ + denoise block that takes encoded conditions and runs the denoising process. + + Components: + transformer (`ZImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`list`): + Pre-generated text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`list`, *optional*): + Pre-generated negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 9): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + **denoiser_input_fields (`None`, *optional*): + The conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + block_classes = [ + ZImageTextInputStep, + ZImagePrepareLatentsStep, + ZImageSetTimestepsStep, + ZImageDenoiseStep, + ] + block_names = ["input", "prepare_latents", "set_timesteps", "denoise"] + + @property + def description(self): + return "denoise block that takes encoded conditions and runs the denoising process." + + @property + def outputs(self): + return [OutputParam.template("latents")] + + +# image2image: inputs(text + image_latents) -> prepare_latents -> set_timesteps -> set_timesteps_with_strength -> prepare_latents_with_image -> denoise +# auto_docstring +class ZImageImage2ImageCoreDenoiseStep(SequentialPipelineBlocks): + """ + denoise block that takes encoded text and image latent conditions and runs the denoising process. + + Components: + transformer (`ZImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`list`): + Pre-generated text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`list`, *optional*): + Pre-generated negative text embeddings. Can be generated from text_encoder step. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + image_latents (`None`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 9): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + strength (`None`, *optional*, defaults to 0.6): + TODO: Add description. + **denoiser_input_fields (`None`, *optional*): + The conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + block_classes = [ + ZImageTextInputStep, + ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"]), + ZImagePrepareLatentsStep, + ZImageSetTimestepsStep, + ZImageSetTimestepsWithStrengthStep, + ZImagePrepareLatentswithImageStep, + ZImageDenoiseStep, + ] + block_names = [ + "input", + "additional_inputs", + "prepare_latents", + "set_timesteps", + "set_timesteps_with_strength", + "prepare_latents_with_image", + "denoise", + ] + + @property + def description(self): + return "denoise block that takes encoded text and image latent conditions and runs the denoising process." + + @property + def outputs(self): + return [OutputParam.template("latents")] + + +# auto_docstring +class ZImageAutoDenoiseStep(AutoPipelineBlocks): + """ + Denoise step that iteratively denoise the latents. This is a auto pipeline block that works for text2image and + image2image tasks. - `ZImageCoreDenoiseStep` (text2image) for text2image tasks. - + `ZImageImage2ImageCoreDenoiseStep` (image2image) for image2image tasks. - if `image_latents` is provided, + `ZImageImage2ImageCoreDenoiseStep` will be used. + - if `image_latents` is not provided, `ZImageCoreDenoiseStep` will be used. + + Components: + transformer (`ZImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`list`): + Pre-generated text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`list`, *optional*): + Pre-generated negative text embeddings. Can be generated from text_encoder step. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + image_latents (`None`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + strength (`None`, *optional*, defaults to 0.6): + TODO: Add description. + **denoiser_input_fields (`None`, *optional*): + The conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + block_classes = [ + ZImageImage2ImageCoreDenoiseStep, + ZImageCoreDenoiseStep, + ] + block_names = ["image2image", "text2image"] + block_trigger_inputs = ["image_latents", None] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. " + "This is a auto pipeline block that works for text2image and image2image tasks." + " - `ZImageCoreDenoiseStep` (text2image) for text2image tasks." + " - `ZImageImage2ImageCoreDenoiseStep` (image2image) for image2image tasks." + + " - if `image_latents` is provided, `ZImageImage2ImageCoreDenoiseStep` will be used.\n" + + " - if `image_latents` is not provided, `ZImageCoreDenoiseStep` will be used.\n" + ) + + +# auto_docstring +class ZImageAutoVaeImageEncoderStep(AutoPipelineBlocks): + """ + Vae Image Encoder step that encode the image to generate the image latents + + Components: + vae (`AutoencoderKL`) image_processor (`VaeImageProcessor`) + + Inputs: + image (`Image`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + image_latents (`Tensor`): + video latent representation with the first frame image condition + """ + + block_classes = [ZImageVaeImageEncoderStep] + block_names = ["vae_encoder"] + block_trigger_inputs = ["image"] + + @property + def description(self) -> str: + return "Vae Image Encoder step that encode the image to generate the image latents" + +"This is an auto pipeline block that works for image2image tasks." + +" - `ZImageVaeImageEncoderStep` is used when `image` is provided." + +" - if `image` is not provided, step will be skipped." + + +# auto_docstring +class ZImageAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for text-to-image and image-to-image using ZImage. + + Supported workflows: + - `text2image`: requires `prompt` + - `image2image`: requires `image`, `prompt` + + Components: + text_encoder (`Qwen3Model`) tokenizer (`Qwen2Tokenizer`) guider (`ClassifierFreeGuidance`) vae + (`AutoencoderKL`) image_processor (`VaeImageProcessor`) transformer (`ZImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + prompt (`None`, *optional*): + TODO: Add description. + negative_prompt (`None`, *optional*): + TODO: Add description. + max_sequence_length (`None`, *optional*, defaults to 512): + TODO: Add description. + image (`Image`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + image_latents (`None`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`): + TODO: Add description. + num_inference_steps (`None`): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + strength (`None`, *optional*, defaults to 0.6): + TODO: Add description. + **denoiser_input_fields (`None`, *optional*): + The conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + output_type (`str`, *optional*, defaults to pil): + The type of the output images, can be 'pil', 'np', 'pt' + + Outputs: + images (`list`): + Generated images. + """ + + block_classes = [ + ZImageTextEncoderStep, + ZImageAutoVaeImageEncoderStep, + ZImageAutoDenoiseStep, + ZImageVaeDecoderStep, + ] + block_names = ["text_encoder", "vae_encoder", "denoise", "decode"] + _workflow_map = { + "text2image": {"prompt": True}, + "image2image": {"image": True, "prompt": True}, + } + + @property + def description(self) -> str: + return "Auto Modular pipeline for text-to-image and image-to-image using ZImage." + + @property + def outputs(self): + return [OutputParam.template("images")] diff --git a/diffusers/modular_pipelines/z_image/modular_pipeline.py b/diffusers/modular_pipelines/z_image/modular_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..f1d8e53a3639d481134c418c47a4fc0a9a95400f --- /dev/null +++ b/diffusers/modular_pipelines/z_image/modular_pipeline.py @@ -0,0 +1,72 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...loaders import ZImageLoraLoaderMixin +from ...utils import logging +from ..modular_pipeline import ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class ZImageModularPipeline( + ModularPipeline, + ZImageLoraLoaderMixin, +): + """ + A ModularPipeline for Z-Image. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "ZImageAutoBlocks" + + @property + def default_height(self): + return 1024 + + @property + def default_width(self): + return 1024 + + @property + def vae_scale_factor_spatial(self): + vae_scale_factor_spatial = 16 + if hasattr(self, "image_processor") and self.image_processor is not None: + vae_scale_factor_spatial = self.image_processor.config.vae_scale_factor + return vae_scale_factor_spatial + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_latents(self): + num_channels_latents = 16 + if hasattr(self, "transformer") and self.transformer is not None: + num_channels_latents = self.transformer.config.in_channels + return num_channels_latents + + @property + def requires_unconditional_embeds(self): + requires_unconditional_embeds = False + + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 + + return requires_unconditional_embeds diff --git a/diffusers/pipelines/allegro/__init__.py b/diffusers/pipelines/allegro/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2162b825e0a2d59aa79430fc0b66f664b6feccb5 --- /dev/null +++ b/diffusers/pipelines/allegro/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_allegro"] = ["AllegroPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_allegro import AllegroPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/allegro/pipeline_allegro.py b/diffusers/pipelines/allegro/pipeline_allegro.py new file mode 100644 index 0000000000000000000000000000000000000000..e54e9ed207396b09b4a9255938f40ace37a12218 --- /dev/null +++ b/diffusers/pipelines/allegro/pipeline_allegro.py @@ -0,0 +1,986 @@ +# Copyright 2025 The RhymesAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import math +import re +import urllib.parse as ul +from typing import Callable + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...models import AllegroTransformer3DModel, AutoencoderKLAllegro +from ...models.embeddings import get_3d_rotary_pos_embed_allegro +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + BACKENDS_MAPPING, + deprecate, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import AllegroPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoencoderKLAllegro, AllegroPipeline + >>> from diffusers.utils import export_to_video + + >>> vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32) + >>> pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", vae=vae, torch_dtype=torch.bfloat16).to("cuda") + >>> pipe.enable_vae_tiling() + + >>> prompt = ( + ... "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, " + ... "the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this " + ... "location might be a popular spot for docking fishing boats." + ... ) + >>> video = pipe(prompt, guidance_scale=7.5, max_sequence_length=512).frames[0] + >>> export_to_video(video, "output.mp4", fps=15) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AllegroPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Allegro. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AllegroAutoEncoderKL3D`]): + Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. PixArt-Alpha uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`AllegroTransformer3DModel`]): + A text conditioned `AllegroTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLAllegro, + transformer: AllegroTransformer3DModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->512, num_images_per_prompt->num_videos_per_prompt + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_videos_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + clean_caption: bool = False, + max_sequence_length: int = 512, + **kwargs, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_videos_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" + string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 512): Maximum sequence length to use for the prompt. + """ + + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + + if device is None: + device = self._execution_device + + # See Section 3.1. of the paper. + max_length = max_sequence_length + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because T5 can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt) + prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_videos_per_prompt, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = uncond_input.attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_videos_per_prompt) + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_videos_per_prompt, -1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + num_frames, + height, + width, + callback_on_step_end_tensor_inputs, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if num_frames <= 0: + raise ValueError(f"`num_frames` have to be positive but is {num_frames}.") + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if num_frames % 2 == 0: + num_frames = math.ceil(num_frames / self.vae_scale_factor_temporal) + else: + num_frames = math.ceil((num_frames - 1) / self.vae_scale_factor_temporal) + 1 + + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = 1 / self.vae.config.scaling_factor * latents + frames = self.vae.decode(latents).sample + frames = frames.permute(0, 2, 1, 3, 4) # [batch_size, channels, num_frames, height, width] + return frames + + def _prepare_rotary_positional_embeddings( + self, + batch_size: int, + height: int, + width: int, + num_frames: int, + device: torch.device, + ): + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + start, stop = (0, 0), (grid_height, grid_width) + freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w = get_3d_rotary_pos_embed_allegro( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=(start, stop), + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + interpolation_scale=( + self.transformer.config.interpolation_scale_t, + self.transformer.config.interpolation_scale_h, + self.transformer.config.interpolation_scale_w, + ), + device=device, + ) + + grid_t = grid_t.to(dtype=torch.long) + grid_h = grid_h.to(dtype=torch.long) + grid_w = grid_w.to(dtype=torch.long) + + pos = torch.cartesian_prod(grid_t, grid_h, grid_w) + pos = pos.reshape(-1, 3).transpose(0, 1).reshape(3, 1, -1).contiguous() + grid_t, grid_h, grid_w = pos + + return (freqs_t, freqs_h, freqs_w), (grid_t, grid_h, grid_w) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str = "", + num_inference_steps: int = 100, + timesteps: list[int] = None, + guidance_scale: float = 7.5, + num_frames: int | None = None, + height: int | None = None, + width: int | None = None, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + clean_caption: bool = True, + max_sequence_length: int = 512, + ) -> AllegroPipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked to + the text `prompt`, usually at the expense of lower video quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + num_frames: (`int`, *optional*, defaults to 88): + The number controls the generated video frames. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated video. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate video. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + max_sequence_length (`int` defaults to `512`): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.allegro.pipeline_output.AllegroPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.allegro.pipeline_output.AllegroPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated videos. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + num_frames = num_frames or self.transformer.config.sample_frames * self.vae_scale_factor_temporal + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + + self.check_inputs( + prompt, + num_frames, + height, + width, + callback_on_step_end_tensor_inputs, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + self._guidance_scale = guidance_scale + self._current_timestep = None + self._interrupt = False + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + if prompt_embeds.ndim == 3: + prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) + self.scheduler.set_timesteps(num_inference_steps, device=device) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare rotary embeddings + image_rotary_emb = self._prepare_rotary_positional_embeddings( + batch_size, height, width, latents.size(2), device + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + video = self.decode_latents(latents) + video = video[:, :, :num_frames, :height, :width] + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return AllegroPipelineOutput(frames=video) diff --git a/diffusers/pipelines/allegro/pipeline_output.py b/diffusers/pipelines/allegro/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..bf85a4954ce986c3cea771dcc1258596e857951f --- /dev/null +++ b/diffusers/pipelines/allegro/pipeline_output.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass + +import numpy as np +import PIL +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class AllegroPipelineOutput(BaseOutput): + r""" + Output class for Allegro pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor | np.ndarray | list[list[PIL.Image.Image]] diff --git a/diffusers/pipelines/amused/__init__.py b/diffusers/pipelines/amused/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3c4d07a426b54fabfcdf35bfb8e4486cd828b3b3 --- /dev/null +++ b/diffusers/pipelines/amused/__init__.py @@ -0,0 +1,62 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + AmusedImg2ImgPipeline, + AmusedInpaintPipeline, + AmusedPipeline, + ) + + _dummy_objects.update( + { + "AmusedPipeline": AmusedPipeline, + "AmusedImg2ImgPipeline": AmusedImg2ImgPipeline, + "AmusedInpaintPipeline": AmusedInpaintPipeline, + } + ) +else: + _import_structure["pipeline_amused"] = ["AmusedPipeline"] + _import_structure["pipeline_amused_img2img"] = ["AmusedImg2ImgPipeline"] + _import_structure["pipeline_amused_inpaint"] = ["AmusedInpaintPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + AmusedPipeline, + ) + else: + from .pipeline_amused import AmusedPipeline + from .pipeline_amused_img2img import AmusedImg2ImgPipeline + from .pipeline_amused_inpaint import AmusedInpaintPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/amused/pipeline_amused.py b/diffusers/pipelines/amused/pipeline_amused.py new file mode 100644 index 0000000000000000000000000000000000000000..b23adf0d2152b8eafc2cb0d772fb498d5b161661 --- /dev/null +++ b/diffusers/pipelines/amused/pipeline_amused.py @@ -0,0 +1,342 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable + +import torch +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...models import UVit2DModel, VQModel +from ...schedulers import AmusedScheduler +from ...utils import is_torch_xla_available, replace_example_docstring +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AmusedPipeline + + >>> pipe = AmusedPipeline.from_pretrained("amused/amused-512", variant="fp16", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +class AmusedPipeline(DeprecatedPipelineMixin, DiffusionPipeline): + _last_supported_version = "0.33.1" + image_processor: VaeImageProcessor + vqvae: VQModel + tokenizer: CLIPTokenizer + text_encoder: CLIPTextModelWithProjection + transformer: UVit2DModel + scheduler: AmusedScheduler + + model_cpu_offload_seq = "text_encoder->transformer->vqvae" + + def __init__( + self, + vqvae: VQModel, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + transformer: UVit2DModel, + scheduler: AmusedScheduler, + ): + super().__init__() + + self.register_modules( + vqvae=vqvae, + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: list[str] | str | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 12, + guidance_scale: float = 10.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | None = None, + latents: torch.IntTensor | None = None, + prompt_embeds: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_encoder_hidden_states: torch.Tensor | None = None, + output_type="pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + micro_conditioning_aesthetic_score: int = 6, + micro_conditioning_crop_coord: tuple[int, int] = (0, 0), + temperature: int | tuple[int, int] | list[int] = (2, 0), + ): + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 16): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 10.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.IntTensor`, *optional*): + Pre-generated tokens representing latent vectors in `self.vqvae`, to be used as inputs for image + generation. If not provided, the starting latents will be completely masked. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. A single vector from the + pooled and projected final hidden states. + encoder_hidden_states (`torch.Tensor`, *optional*): + Pre-generated penultimate hidden states from the text encoder providing additional text conditioning. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_encoder_hidden_states (`torch.Tensor`, *optional*): + Analogous to `encoder_hidden_states` for the positive prompt. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6): + The targeted aesthetic score according to the laion aesthetic classifier. See + https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of + https://huggingface.co/papers/2307.01952. + micro_conditioning_crop_coord (`tuple[int]`, *optional*, defaults to (0, 0)): + The targeted height, width crop coordinates. See the micro-conditioning section of + https://huggingface.co/papers/2307.01952. + temperature (`int | tuple[int, int, list[int]]`, *optional*, defaults to (2, 0)): + Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`. + + Examples: + + Returns: + [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a + `tuple` is returned where the first element is a list with the generated images. + """ + if (prompt_embeds is not None and encoder_hidden_states is None) or ( + prompt_embeds is None and encoder_hidden_states is not None + ): + raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither") + + if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or ( + negative_prompt_embeds is None and negative_encoder_hidden_states is not None + ): + raise ValueError( + "pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither" + ) + + if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None): + raise ValueError("pass only one of `prompt` or `prompt_embeds`") + + if isinstance(prompt, str): + prompt = [prompt] + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + batch_size = batch_size * num_images_per_prompt + + if height is None: + height = self.transformer.config.sample_size * self.vae_scale_factor + + if width is None: + width = self.transformer.config.sample_size * self.vae_scale_factor + + if prompt_embeds is None: + input_ids = self.tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + prompt_embeds = outputs.text_embeds + encoder_hidden_states = outputs.hidden_states[-2] + + prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1) + encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + if guidance_scale > 1.0: + if negative_prompt_embeds is None: + if negative_prompt is None: + negative_prompt = [""] * len(prompt) + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + input_ids = self.tokenizer( + negative_prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + negative_prompt_embeds = outputs.text_embeds + negative_encoder_hidden_states = outputs.hidden_states[-2] + + negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1) + negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds]) + encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states]) + + # Note that the micro conditionings _do_ flip the order of width, height for the original size + # and the crop coordinates. This is how it was done in the original code base + micro_conds = torch.tensor( + [ + width, + height, + micro_conditioning_crop_coord[0], + micro_conditioning_crop_coord[1], + micro_conditioning_aesthetic_score, + ], + device=self._execution_device, + dtype=encoder_hidden_states.dtype, + ) + micro_conds = micro_conds.unsqueeze(0) + micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1) + + shape = (batch_size, height // self.vae_scale_factor, width // self.vae_scale_factor) + + if latents is None: + latents = torch.full( + shape, self.scheduler.config.mask_token_id, dtype=torch.long, device=self._execution_device + ) + + self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device) + + num_warmup_steps = len(self.scheduler.timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, timestep in enumerate(self.scheduler.timesteps): + if guidance_scale > 1.0: + model_input = torch.cat([latents] * 2) + else: + model_input = latents + + model_output = self.transformer( + model_input, + micro_conds=micro_conds, + pooled_text_emb=prompt_embeds, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if guidance_scale > 1.0: + uncond_logits, cond_logits = model_output.chunk(2) + model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + + latents = self.scheduler.step( + model_output=model_output, + timestep=timestep, + sample=latents, + generator=generator, + ).prev_sample + + if i == len(self.scheduler.timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, timestep, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + output = latents + else: + needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast + + if needs_upcasting: + self.vqvae.float() + + output = self.vqvae.decode( + latents, + force_not_quantize=True, + shape=( + batch_size, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + self.vqvae.config.latent_channels, + ), + ).sample.clip(0, 1) + output = self.image_processor.postprocess(output, output_type) + + if needs_upcasting: + self.vqvae.half() + + self.maybe_free_model_hooks() + + if not return_dict: + return (output,) + + return ImagePipelineOutput(output) diff --git a/diffusers/pipelines/amused/pipeline_amused_img2img.py b/diffusers/pipelines/amused/pipeline_amused_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..79ebd96dedebefc295781cb9eb1c7bd09ab5459b --- /dev/null +++ b/diffusers/pipelines/amused/pipeline_amused_img2img.py @@ -0,0 +1,363 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable + +import torch +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models import UVit2DModel, VQModel +from ...schedulers import AmusedScheduler +from ...utils import is_torch_xla_available, replace_example_docstring +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AmusedImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> pipe = AmusedImg2ImgPipeline.from_pretrained( + ... "amused/amused-512", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "winter mountains" + >>> input_image = ( + ... load_image( + ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg" + ... ) + ... .resize((512, 512)) + ... .convert("RGB") + ... ) + >>> image = pipe(prompt, input_image).images[0] + ``` +""" + + +class AmusedImg2ImgPipeline(DeprecatedPipelineMixin, DiffusionPipeline): + _last_supported_version = "0.33.1" + image_processor: VaeImageProcessor + vqvae: VQModel + tokenizer: CLIPTokenizer + text_encoder: CLIPTextModelWithProjection + transformer: UVit2DModel + scheduler: AmusedScheduler + + model_cpu_offload_seq = "text_encoder->transformer->vqvae" + + # TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before + # the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter + # off the meta device. There should be a way to fix this instead of just not offloading it + _exclude_from_cpu_offload = ["vqvae"] + + def __init__( + self, + vqvae: VQModel, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + transformer: UVit2DModel, + scheduler: AmusedScheduler, + ): + super().__init__() + + self.register_modules( + vqvae=vqvae, + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: list[str] | str | None = None, + image: PipelineImageInput = None, + strength: float = 0.5, + num_inference_steps: int = 12, + guidance_scale: float = 10.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | None = None, + prompt_embeds: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_encoder_hidden_states: torch.Tensor | None = None, + output_type="pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + micro_conditioning_aesthetic_score: int = 6, + micro_conditioning_crop_coord: tuple[int, int] = (0, 0), + temperature: int | tuple[int, int] | list[int] = (2, 0), + ): + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + strength (`float`, *optional*, defaults to 0.5): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 12): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 10.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. A single vector from the + pooled and projected final hidden states. + encoder_hidden_states (`torch.Tensor`, *optional*): + Pre-generated penultimate hidden states from the text encoder providing additional text conditioning. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_encoder_hidden_states (`torch.Tensor`, *optional*): + Analogous to `encoder_hidden_states` for the positive prompt. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6): + The targeted aesthetic score according to the laion aesthetic classifier. See + https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of + https://huggingface.co/papers/2307.01952. + micro_conditioning_crop_coord (`tuple[int]`, *optional*, defaults to (0, 0)): + The targeted height, width crop coordinates. See the micro-conditioning section of + https://huggingface.co/papers/2307.01952. + temperature (`int | tuple[int, int, list[int]]`, *optional*, defaults to (2, 0)): + Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`. + + Examples: + + Returns: + [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a + `tuple` is returned where the first element is a list with the generated images. + """ + + if (prompt_embeds is not None and encoder_hidden_states is None) or ( + prompt_embeds is None and encoder_hidden_states is not None + ): + raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither") + + if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or ( + negative_prompt_embeds is None and negative_encoder_hidden_states is not None + ): + raise ValueError( + "pass either both `negative_prompt_embeds` and `negative_encoder_hidden_states` or neither" + ) + + if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None): + raise ValueError("pass only one of `prompt` or `prompt_embeds`") + + if isinstance(prompt, str): + prompt = [prompt] + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + batch_size = batch_size * num_images_per_prompt + + if prompt_embeds is None: + input_ids = self.tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + prompt_embeds = outputs.text_embeds + encoder_hidden_states = outputs.hidden_states[-2] + + prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1) + encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + if guidance_scale > 1.0: + if negative_prompt_embeds is None: + if negative_prompt is None: + negative_prompt = [""] * len(prompt) + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + input_ids = self.tokenizer( + negative_prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + negative_prompt_embeds = outputs.text_embeds + negative_encoder_hidden_states = outputs.hidden_states[-2] + + negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1) + negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds]) + encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states]) + + image = self.image_processor.preprocess(image) + + height, width = image.shape[-2:] + + # Note that the micro conditionings _do_ flip the order of width, height for the original size + # and the crop coordinates. This is how it was done in the original code base + micro_conds = torch.tensor( + [ + width, + height, + micro_conditioning_crop_coord[0], + micro_conditioning_crop_coord[1], + micro_conditioning_aesthetic_score, + ], + device=self._execution_device, + dtype=encoder_hidden_states.dtype, + ) + + micro_conds = micro_conds.unsqueeze(0) + micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1) + + self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device) + num_inference_steps = int(len(self.scheduler.timesteps) * strength) + start_timestep_idx = len(self.scheduler.timesteps) - num_inference_steps + + needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast + + if needs_upcasting: + self.vqvae.float() + + latents = self.vqvae.encode(image.to(dtype=self.vqvae.dtype, device=self._execution_device)).latents + latents_bsz, channels, latents_height, latents_width = latents.shape + latents = self.vqvae.quantize(latents)[2][2].reshape(latents_bsz, latents_height, latents_width) + latents = self.scheduler.add_noise( + latents, self.scheduler.timesteps[start_timestep_idx - 1], generator=generator + ) + latents = latents.repeat(num_images_per_prompt, 1, 1) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i in range(start_timestep_idx, len(self.scheduler.timesteps)): + timestep = self.scheduler.timesteps[i] + + if guidance_scale > 1.0: + model_input = torch.cat([latents] * 2) + else: + model_input = latents + + model_output = self.transformer( + model_input, + micro_conds=micro_conds, + pooled_text_emb=prompt_embeds, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if guidance_scale > 1.0: + uncond_logits, cond_logits = model_output.chunk(2) + model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + + latents = self.scheduler.step( + model_output=model_output, + timestep=timestep, + sample=latents, + generator=generator, + ).prev_sample + + if i == len(self.scheduler.timesteps) - 1 or ((i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, timestep, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + output = latents + else: + output = self.vqvae.decode( + latents, + force_not_quantize=True, + shape=( + batch_size, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + self.vqvae.config.latent_channels, + ), + ).sample.clip(0, 1) + output = self.image_processor.postprocess(output, output_type) + + if needs_upcasting: + self.vqvae.half() + + self.maybe_free_model_hooks() + + if not return_dict: + return (output,) + + return ImagePipelineOutput(output) diff --git a/diffusers/pipelines/amused/pipeline_amused_inpaint.py b/diffusers/pipelines/amused/pipeline_amused_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..55302401832c5683b0601c5690954472a44074b9 --- /dev/null +++ b/diffusers/pipelines/amused/pipeline_amused_inpaint.py @@ -0,0 +1,394 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Callable + +import torch +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models import UVit2DModel, VQModel +from ...schedulers import AmusedScheduler +from ...utils import is_torch_xla_available, replace_example_docstring +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AmusedInpaintPipeline + >>> from diffusers.utils import load_image + + >>> pipe = AmusedInpaintPipeline.from_pretrained( + ... "amused/amused-512", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "fall mountains" + >>> input_image = ( + ... load_image( + ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg" + ... ) + ... .resize((512, 512)) + ... .convert("RGB") + ... ) + >>> mask = ( + ... load_image( + ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png" + ... ) + ... .resize((512, 512)) + ... .convert("L") + ... ) + >>> pipe(prompt, input_image, mask).images[0].save("out.png") + ``` +""" + + +class AmusedInpaintPipeline(DeprecatedPipelineMixin, DiffusionPipeline): + _last_supported_version = "0.33.1" + image_processor: VaeImageProcessor + vqvae: VQModel + tokenizer: CLIPTokenizer + text_encoder: CLIPTextModelWithProjection + transformer: UVit2DModel + scheduler: AmusedScheduler + + model_cpu_offload_seq = "text_encoder->transformer->vqvae" + + # TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before + # the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter + # off the meta device. There should be a way to fix this instead of just not offloading it + _exclude_from_cpu_offload = ["vqvae"] + + def __init__( + self, + vqvae: VQModel, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + transformer: UVit2DModel, + scheduler: AmusedScheduler, + ): + super().__init__() + + self.register_modules( + vqvae=vqvae, + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + do_resize=True, + ) + self.scheduler.register_to_config(masking_schedule="linear") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: list[str] | str | None = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + strength: float = 1.0, + num_inference_steps: int = 12, + guidance_scale: float = 10.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | None = None, + prompt_embeds: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_encoder_hidden_states: torch.Tensor | None = None, + output_type="pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + micro_conditioning_aesthetic_score: int = 6, + micro_conditioning_crop_coord: tuple[int, int] = (0, 0), + temperature: int | tuple[int, int] | list[int] = (2, 0), + ): + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 16): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 10.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. A single vector from the + pooled and projected final hidden states. + encoder_hidden_states (`torch.Tensor`, *optional*): + Pre-generated penultimate hidden states from the text encoder providing additional text conditioning. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_encoder_hidden_states (`torch.Tensor`, *optional*): + Analogous to `encoder_hidden_states` for the positive prompt. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6): + The targeted aesthetic score according to the laion aesthetic classifier. See + https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of + https://huggingface.co/papers/2307.01952. + micro_conditioning_crop_coord (`tuple[int]`, *optional*, defaults to (0, 0)): + The targeted height, width crop coordinates. See the micro-conditioning section of + https://huggingface.co/papers/2307.01952. + temperature (`int | tuple[int, int, list[int]]`, *optional*, defaults to (2, 0)): + Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`. + + Examples: + + Returns: + [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a + `tuple` is returned where the first element is a list with the generated images. + """ + + if (prompt_embeds is not None and encoder_hidden_states is None) or ( + prompt_embeds is None and encoder_hidden_states is not None + ): + raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither") + + if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or ( + negative_prompt_embeds is None and negative_encoder_hidden_states is not None + ): + raise ValueError( + "pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither" + ) + + if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None): + raise ValueError("pass only one of `prompt` or `prompt_embeds`") + + if isinstance(prompt, str): + prompt = [prompt] + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + batch_size = batch_size * num_images_per_prompt + + if prompt_embeds is None: + input_ids = self.tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + prompt_embeds = outputs.text_embeds + encoder_hidden_states = outputs.hidden_states[-2] + + prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1) + encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + if guidance_scale > 1.0: + if negative_prompt_embeds is None: + if negative_prompt is None: + negative_prompt = [""] * len(prompt) + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + input_ids = self.tokenizer( + negative_prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + negative_prompt_embeds = outputs.text_embeds + negative_encoder_hidden_states = outputs.hidden_states[-2] + + negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1) + negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds]) + encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states]) + + image = self.image_processor.preprocess(image) + + height, width = image.shape[-2:] + + # Note that the micro conditionings _do_ flip the order of width, height for the original size + # and the crop coordinates. This is how it was done in the original code base + micro_conds = torch.tensor( + [ + width, + height, + micro_conditioning_crop_coord[0], + micro_conditioning_crop_coord[1], + micro_conditioning_aesthetic_score, + ], + device=self._execution_device, + dtype=encoder_hidden_states.dtype, + ) + + micro_conds = micro_conds.unsqueeze(0) + micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1) + + self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device) + num_inference_steps = int(len(self.scheduler.timesteps) * strength) + start_timestep_idx = len(self.scheduler.timesteps) - num_inference_steps + + needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast + + if needs_upcasting: + self.vqvae.float() + + latents = self.vqvae.encode(image.to(dtype=self.vqvae.dtype, device=self._execution_device)).latents + latents_bsz, channels, latents_height, latents_width = latents.shape + latents = self.vqvae.quantize(latents)[2][2].reshape(latents_bsz, latents_height, latents_width) + + mask = self.mask_processor.preprocess( + mask_image, height // self.vae_scale_factor, width // self.vae_scale_factor + ) + mask = mask.reshape(mask.shape[0], latents_height, latents_width).bool().to(latents.device) + latents[mask] = self.scheduler.config.mask_token_id + + starting_mask_ratio = mask.sum() / latents.numel() + + latents = latents.repeat(num_images_per_prompt, 1, 1) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i in range(start_timestep_idx, len(self.scheduler.timesteps)): + timestep = self.scheduler.timesteps[i] + + if guidance_scale > 1.0: + model_input = torch.cat([latents] * 2) + else: + model_input = latents + + model_output = self.transformer( + model_input, + micro_conds=micro_conds, + pooled_text_emb=prompt_embeds, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if guidance_scale > 1.0: + uncond_logits, cond_logits = model_output.chunk(2) + model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + + latents = self.scheduler.step( + model_output=model_output, + timestep=timestep, + sample=latents, + generator=generator, + starting_mask_ratio=starting_mask_ratio, + ).prev_sample + + if i == len(self.scheduler.timesteps) - 1 or ((i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, timestep, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + output = latents + else: + output = self.vqvae.decode( + latents, + force_not_quantize=True, + shape=( + batch_size, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + self.vqvae.config.latent_channels, + ), + ).sample.clip(0, 1) + output = self.image_processor.postprocess(output, output_type) + + if needs_upcasting: + self.vqvae.half() + + self.maybe_free_model_hooks() + + if not return_dict: + return (output,) + + return ImagePipelineOutput(output) diff --git a/diffusers/pipelines/audioldm/__init__.py b/diffusers/pipelines/audioldm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a002b4aa72e0a180c7042c406667d37122d6e4cc --- /dev/null +++ b/diffusers/pipelines/audioldm/__init__.py @@ -0,0 +1,51 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + AudioLDMPipeline, + ) + + _dummy_objects.update({"AudioLDMPipeline": AudioLDMPipeline}) +else: + _import_structure["pipeline_audioldm"] = ["AudioLDMPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + AudioLDMPipeline, + ) + + else: + from .pipeline_audioldm import AudioLDMPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/audioldm/pipeline_audioldm.py b/diffusers/pipelines/audioldm/pipeline_audioldm.py new file mode 100644 index 0000000000000000000000000000000000000000..357c3582b21cf838e3b8812853b3c4ba1814c0c3 --- /dev/null +++ b/diffusers/pipelines/audioldm/pipeline_audioldm.py @@ -0,0 +1,558 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +import torch.nn.functional as F +from transformers import ClapTextModelWithProjection, RobertaTokenizer, RobertaTokenizerFast, SpeechT5HifiGan + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import AudioPipelineOutput, DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import AudioLDMPipeline + >>> import torch + >>> import scipy + + >>> repo_id = "cvssp/audioldm-s-full-v2" + >>> pipe = AudioLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "Techno music with a strong, upbeat tempo and high melodic riffs" + >>> audio = pipe(prompt, num_inference_steps=10, audio_length_in_s=5.0).audios[0] + + >>> # save the audio sample as a .wav file + >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio) + ``` +""" + + +class AudioLDMPipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin): + r""" + Pipeline for text-to-audio generation using AudioLDM. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.ClapTextModelWithProjection`]): + Frozen text-encoder (`ClapTextModelWithProjection`, specifically the + [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. + tokenizer ([`PreTrainedTokenizer`]): + A [`~transformers.RobertaTokenizer`] to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded audio latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + vocoder ([`~transformers.SpeechT5HifiGan`]): + Vocoder of class `SpeechT5HifiGan`. + """ + + _last_supported_version = "0.33.1" + model_cpu_offload_seq = "text_encoder->unet->vae" + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: ClapTextModelWithProjection, + tokenizer: RobertaTokenizer | RobertaTokenizerFast, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + vocoder: SpeechT5HifiGan, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + vocoder=vocoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + + def _encode_prompt( + self, + prompt, + device, + num_waveforms_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device (`torch.device`): + torch device + num_waveforms_per_prompt (`int`): + number of waveforms that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the audio generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLAP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask.to(device), + ) + prompt_embeds = prompt_embeds.text_embeds + # additional L_2 normalization over each hidden-state + prompt_embeds = F.normalize(prompt_embeds, dim=-1) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + ( + bs_embed, + seq_len, + ) = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt) + prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_input_ids = uncond_input.input_ids.to(device) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input_ids, + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds.text_embeds + # additional L_2 normalization over each hidden-state + negative_prompt_embeds = F.normalize(negative_prompt_embeds, dim=-1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + mel_spectrogram = self.vae.decode(latents).sample + return mel_spectrogram + + def mel_spectrogram_to_waveform(self, mel_spectrogram): + if mel_spectrogram.dim() == 4: + mel_spectrogram = mel_spectrogram.squeeze(1) + + waveform = self.vocoder(mel_spectrogram) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + waveform = waveform.cpu().float() + return waveform + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + audio_length_in_s, + vocoder_upsample_factor, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor + if audio_length_in_s < min_audio_length_in_s: + raise ValueError( + f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but " + f"is {audio_length_in_s}." + ) + + if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0: + raise ValueError( + f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the " + f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of " + f"{self.vae_scale_factor}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim + def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(self.vocoder.config.model_in_dim) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + audio_length_in_s: float | None = None, + num_inference_steps: int = 10, + guidance_scale: float = 2.5, + negative_prompt: str | list[str] | None = None, + num_waveforms_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int | None = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + output_type: str | None = "np", + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`. + audio_length_in_s (`int`, *optional*, defaults to 5.12): + The length of the generated audio sample in seconds. + num_inference_steps (`int`, *optional*, defaults to 10): + The number of denoising steps. More denoising steps usually lead to a higher quality audio at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 2.5): + A higher guidance scale value encourages the model to generate audio that is closely linked to the text + `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in audio generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_waveforms_per_prompt (`int`, *optional*, defaults to 1): + The number of waveforms to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `"np"` to return a NumPy `np.ndarray` or + `"pt"` to return a PyTorch `torch.Tensor` object. + + Examples: + + Returns: + [`~pipelines.AudioPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.AudioPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated audio. + """ + # 0. Convert audio input length from seconds to spectrogram height + vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate + + if audio_length_in_s is None: + audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor + + height = int(audio_length_in_s / vocoder_upsample_factor) + + original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate) + if height % self.vae_scale_factor != 0: + height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor + logger.info( + f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} " + f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the " + f"denoising process." + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + audio_length_in_s, + vocoder_upsample_factor, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_waveforms_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_waveforms_per_prompt, + num_channels_latents, + height, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=None, + class_labels=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # 8. Post-processing + mel_spectrogram = self.decode_latents(latents) + + audio = self.mel_spectrogram_to_waveform(mel_spectrogram) + + audio = audio[:, :original_waveform_length] + + if output_type == "np": + audio = audio.numpy() + + if not return_dict: + return (audio,) + + return AudioPipelineOutput(audios=audio) diff --git a/diffusers/pipelines/blip_diffusion/__init__.py b/diffusers/pipelines/blip_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c245313e2f8ac578542b2db88750cea64de67b94 --- /dev/null +++ b/diffusers/pipelines/blip_diffusion/__init__.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass + +import numpy as np +import PIL +from PIL import Image + +from ...utils import OptionalDependencyNotAvailable, is_torch_available, is_transformers_available + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline +else: + from .blip_image_processing import BlipImageProcessor + from .modeling_blip2 import Blip2QFormerModel + from .modeling_ctx_clip import ContextCLIPTextModel + from .pipeline_blip_diffusion import BlipDiffusionPipeline diff --git a/diffusers/pipelines/blip_diffusion/blip_image_processing.py b/diffusers/pipelines/blip_diffusion/blip_image_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2a0186f041d63e9274bdbcc2f8b53fb072175a --- /dev/null +++ b/diffusers/pipelines/blip_diffusion/blip_image_processing.py @@ -0,0 +1,316 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for BLIP.""" + +import numpy as np +import torch +from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from transformers.image_transforms import convert_to_rgb, resize, to_channel_dimension_format +from transformers.image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, +) +from transformers.utils import TensorType, is_vision_available, logging + +from diffusers.utils import numpy_to_pil + + +if is_vision_available(): + import PIL.Image + + +logger = logging.get_logger(__name__) + + +# We needed some extra functions on top of the ones in transformers.image_processing_utils.BaseImageProcessor, namely center crop +# Copy-pasted from transformers.models.blip.image_processing_blip.BlipImageProcessor +class BlipImageProcessor(BaseImageProcessor): + r""" + Constructs a BLIP image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the + `do_resize` parameter in the `preprocess` method. + size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`): + Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be + overridden by the `resample` parameter in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Wwhether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the + `do_rescale` parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be + overridden by the `rescale_factor` parameter in the `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. Can be overridden by the `do_normalize` parameter in the `preprocess` method. + image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be + overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: int | float = 1 / 255, + do_normalize: bool = True, + image_mean: float | list[float] | None = None, + image_std: float | list[float] | None = None, + do_convert_rgb: bool = True, + do_center_crop: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 224, "width": 224} + size = get_size_dict(size, default_to_square=True) + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.do_convert_rgb = do_convert_rgb + self.do_center_crop = do_center_crop + + # Copy-pasted from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC + def resize( + self, + image: np.ndarray, + size: dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: str | ChannelDimension | None = None, + input_data_format: str | ChannelDimension | None = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`np.ndarray`): + Image to resize. + size (`dict[str, int]`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + + Returns: + `np.ndarray`: The resized image. + """ + size = get_size_dict(size) + if "height" not in size or "width" not in size: + raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}") + output_size = (size["height"], size["width"]) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def preprocess( + self, + images: ImageInput, + do_resize: bool | None = None, + size: dict[str, int] | None = None, + resample: PILImageResampling = None, + do_rescale: bool | None = None, + do_center_crop: bool | None = None, + rescale_factor: float | None = None, + do_normalize: bool | None = None, + image_mean: float | list[float] | None = None, + image_std: float | list[float] | None = None, + return_tensors: str | TensorType | None = None, + do_convert_rgb: bool = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: str | ChannelDimension | None = None, + **kwargs, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Controls the size of the image after `resize`. The shortest edge of the image is resized to + `size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image + is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest + edge equal to `int(size["shortest_edge"] * (1333 / 800))`. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean to normalize the image by if `do_normalize` is set to `True`. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to normalize the image by if `do_normalize` is set to `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + if do_resize and size is None or resample is None: + raise ValueError("Size and resample must be specified if do_resize is True.") + + if do_rescale and rescale_factor is None: + raise ValueError("Rescale factor must be specified if do_rescale is True.") + + if do_normalize and (image_mean is None or image_std is None): + raise ValueError("Image mean and std must be specified if do_normalize is True.") + + # PIL RGBA images are converted to RGB + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + images = [ + self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + if do_center_crop: + images = [self.center_crop(image, size, input_data_format=input_data_format) for image in images] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors) + return encoded_outputs + + # Follows diffusers.VaeImageProcessor.postprocess + def postprocess(self, sample: torch.Tensor, output_type: str = "pil"): + if output_type not in ["pt", "np", "pil"]: + raise ValueError( + f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']" + ) + + # Equivalent to diffusers.VaeImageProcessor.denormalize + sample = (sample / 2 + 0.5).clamp(0, 1) + if output_type == "pt": + return sample + + # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy + sample = sample.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "np": + return sample + # Output_type must be 'pil' + sample = numpy_to_pil(sample) + return sample diff --git a/diffusers/pipelines/blip_diffusion/modeling_blip2.py b/diffusers/pipelines/blip_diffusion/modeling_blip2.py new file mode 100644 index 0000000000000000000000000000000000000000..c434ccdaccca70a3a3cb2f02b015b417071a0950 --- /dev/null +++ b/diffusers/pipelines/blip_diffusion/modeling_blip2.py @@ -0,0 +1,636 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch import nn +from transformers import BertTokenizer +from transformers.activations import QuickGELUActivation as QuickGELU +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPooling, + BaseModelOutputWithPoolingAndCrossAttentions, +) +from transformers.models.blip_2.configuration_blip_2 import Blip2Config, Blip2VisionConfig +from transformers.models.blip_2.modeling_blip_2 import ( + Blip2Encoder, + Blip2PreTrainedModel, + Blip2QFormerAttention, + Blip2QFormerIntermediate, + Blip2QFormerOutput, +) +from transformers.pytorch_utils import apply_chunking_to_forward +from transformers.utils import ( + logging, + replace_return_docstrings, +) + + +logger = logging.get_logger(__name__) + + +# There is an implementation of Blip2 in `transformers` : https://github.com/huggingface/transformers/blob/main/src/transformers/models/blip_2/modeling_blip_2.py. +# But it doesn't support getting multimodal embeddings. So, this module can be +# replaced with a future `transformers` version supports that. +class Blip2TextEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + batch_size = embeddings.shape[0] + # repeat the query embeddings for batch size + query_embeds = query_embeds.repeat(batch_size, 1, 1) + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + embeddings = embeddings.to(query_embeds.dtype) + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Blip2 +class Blip2VisionEmbeddings(nn.Module): + def __init__(self, config: Blip2VisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype) + return embeddings + + +# The Qformer encoder, which takes the visual embeddings, and the text input, to get multimodal embeddings +class Blip2QFormerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [Blip2QFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions else None + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and torch.is_grad_enabled(): + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + layer_outputs = self._gradient_checkpointing_func( + layer_module, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if layer_module.has_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# The layers making up the Qformer encoder +class Blip2QFormerLayer(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = Blip2QFormerAttention(config) + + self.layer_idx = layer_idx + + if layer_idx % config.cross_attention_frequency == 0: + self.crossattention = Blip2QFormerAttention(config, is_cross_attention=True) + self.has_cross_attention = True + else: + self.has_cross_attention = False + + self.intermediate = Blip2QFormerIntermediate(config) + self.intermediate_query = Blip2QFormerIntermediate(config) + self.output_query = Blip2QFormerOutput(config) + self.output = Blip2QFormerOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + if encoder_hidden_states is None: + raise ValueError("encoder_hidden_states must be given for cross-attention layers") + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:-1] + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = torch.cat([layer_output, layer_output_text], dim=1) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query(self, attention_output): + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +# ProjLayer used to project the multimodal Blip2 embeddings to be used in the text encoder +class ProjLayer(nn.Module): + def __init__(self, in_dim, out_dim, hidden_dim, drop_p=0.1, eps=1e-12): + super().__init__() + + # Dense1 -> Act -> Dense2 -> Drop -> Res -> Norm + self.dense1 = nn.Linear(in_dim, hidden_dim) + self.act_fn = QuickGELU() + self.dense2 = nn.Linear(hidden_dim, out_dim) + self.dropout = nn.Dropout(drop_p) + + self.LayerNorm = nn.LayerNorm(out_dim, eps=eps) + + def forward(self, x): + x_in = x + + x = self.LayerNorm(x) + x = self.dropout(self.dense2(self.act_fn(self.dense1(x)))) + x_in + + return x + + +# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->Blip2, BLIP->BLIP_2 +class Blip2VisionModel(Blip2PreTrainedModel): + main_input_name = "pixel_values" + config_class = Blip2VisionConfig + + def __init__(self, config: Blip2VisionConfig): + super().__init__(config) + self.config = config + embed_dim = config.hidden_size + self.embeddings = Blip2VisionEmbeddings(config) + self.pre_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.encoder = Blip2Encoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.post_init() + + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Blip2VisionConfig) + def forward( + self, + pixel_values: torch.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | BaseModelOutputWithPooling: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layernorm(hidden_states) + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def get_input_embeddings(self): + return self.embeddings + + +# Qformer model, used to get multimodal embeddings from the text and image inputs +class Blip2QFormerModel(Blip2PreTrainedModel): + """ + Querying Transformer (Q-Former), used in BLIP-2. + """ + + def __init__(self, config: Blip2Config): + super().__init__(config) + self.config = config + self.embeddings = Blip2TextEmbeddings(config.qformer_config) + self.visual_encoder = Blip2VisionModel(config.vision_config) + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + if not hasattr(config, "tokenizer") or config.tokenizer is None: + self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right") + else: + self.tokenizer = BertTokenizer.from_pretrained(config.tokenizer, truncation_side="right") + self.tokenizer.add_special_tokens({"bos_token": "[DEC]"}) + self.proj_layer = ProjLayer( + in_dim=config.qformer_config.hidden_size, + out_dim=config.qformer_config.hidden_size, + hidden_dim=config.qformer_config.hidden_size * 4, + drop_p=0.1, + eps=1e-12, + ) + + self.encoder = Blip2QFormerEncoder(config.qformer_config) + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: torch.Tensor, + input_shape: tuple[int], + device: torch.device, + has_query: bool = False, + ) -> torch.Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (`tuple[int]`): + The shape of the input to the model. + device (`torch.device`): + The device of the input to the model. + + Returns: + `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + text_input=None, + image_input=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.Tensor))` of length `config.n_layers` with each tuple having 4 tensors of: + shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and + value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are + used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key + value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape + `(batch_size, sequence_length)`. + use_cache (`bool`, `optional`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + + text = self.tokenizer(text_input, return_tensors="pt", padding=True) + text = text.to(self.device) + input_ids = text.input_ids + batch_size = input_ids.shape[0] + query_atts = torch.ones((batch_size, self.query_tokens.size()[1]), dtype=torch.long).to(self.device) + attention_mask = torch.cat([query_atts, text.attention_mask], dim=1) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0 + ) + + query_length = self.query_tokens.shape[1] + + embedding_output = self.embeddings( + input_ids=input_ids, + query_embeds=self.query_tokens, + past_key_values_length=past_key_values_length, + ) + + # embedding_output = self.layernorm(query_embeds) + # embedding_output = self.dropout(embedding_output) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + image_embeds_frozen = self.visual_encoder(image_input).last_hidden_state + # image_embeds_frozen = torch.ones_like(image_embeds_frozen) + encoder_hidden_states = image_embeds_frozen + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if isinstance(encoder_hidden_states, list): + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if isinstance(encoder_attention_mask, list): + encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.qformer_config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = sequence_output[:, 0, :] + + if not return_dict: + return self.proj_layer(sequence_output[:, :query_length, :]) + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) diff --git a/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py b/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..c5364f8985aabd82679caf5bfc20d04c1d93686e --- /dev/null +++ b/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py @@ -0,0 +1,221 @@ +# Copyright 2025 Salesforce.com, inc. +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch import nn +from transformers import CLIPPreTrainedModel +from transformers.modeling_outputs import BaseModelOutputWithPooling +from transformers.models.clip.configuration_clip import CLIPTextConfig +from transformers.models.clip.modeling_clip import CLIPEncoder + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +# This is a modified version of the CLIPTextModel from transformers.models.clip.modeling_clip +# Which allows for an extra input of "context embeddings", which are the query embeddings used in Qformer +# They pass through the clip model, along with the text embeddings, and interact with them using self attention +class ContextCLIPTextModel(CLIPPreTrainedModel): + config_class = CLIPTextConfig + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPTextConfig): + super().__init__(config) + self.text_model = ContextCLIPTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + ctx_embeddings: torch.Tensor = None, + ctx_begin_pos: list = None, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | BaseModelOutputWithPooling: + return self.text_model( + ctx_embeddings=ctx_embeddings, + ctx_begin_pos=ctx_begin_pos, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class ContextCLIPTextTransformer(nn.Module): + def __init__(self, config: CLIPTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = ContextCLIPTextEmbeddings(config) + self.encoder = CLIPEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim) + + def forward( + self, + ctx_embeddings: torch.Tensor, + ctx_begin_pos: list, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | BaseModelOutputWithPooling: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify either input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + ctx_embeddings=ctx_embeddings, + ctx_begin_pos=ctx_begin_pos, + ) + + bsz, seq_len = input_shape + if ctx_embeddings is not None: + seq_len += ctx_embeddings.size(1) + # CLIP's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to( + hidden_states.device + ) + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # text_embeds.shape = [batch_size, sequence_length, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=input_ids.device), + input_ids.to(torch.int).argmax(dim=-1), + ] + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def _build_causal_attention_mask(self, bsz, seq_len, dtype): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype) + mask.fill_(torch.tensor(torch.finfo(dtype).min)) + mask.triu_(1) # zero out the lower diagonal + mask = mask.unsqueeze(1) # expand mask + return mask + + +class ContextCLIPTextEmbeddings(nn.Module): + def __init__(self, config: CLIPTextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + + def forward( + self, + ctx_embeddings: torch.Tensor, + ctx_begin_pos: list, + input_ids: torch.LongTensor | None = None, + position_ids: torch.LongTensor | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor: + if ctx_embeddings is None: + ctx_len = 0 + else: + ctx_len = ctx_embeddings.shape[1] + + seq_length = (input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]) + ctx_len + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + # for each input embeddings, add the ctx embeddings at the correct position + input_embeds_ctx = [] + bsz = inputs_embeds.shape[0] + + if ctx_embeddings is not None: + for i in range(bsz): + cbp = ctx_begin_pos[i] + + prefix = inputs_embeds[i, :cbp] + # remove the special token embedding + suffix = inputs_embeds[i, cbp:] + + input_embeds_ctx.append(torch.cat([prefix, ctx_embeddings[i], suffix], dim=0)) + + inputs_embeds = torch.stack(input_embeds_ctx, dim=0) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings diff --git a/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py b/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..aa3dbdae966b35061ea70f3184d25a250c8e9b63 --- /dev/null +++ b/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py @@ -0,0 +1,355 @@ +# Copyright 2025 Salesforce.com, inc. +# Copyright 2025 The HuggingFace Team. All rights reserved.# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import PIL.Image +import torch +from transformers import CLIPTokenizer + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import PNDMScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput +from .blip_image_processing import BlipImageProcessor +from .modeling_blip2 import Blip2QFormerModel +from .modeling_ctx_clip import ContextCLIPTextModel + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers.pipelines import BlipDiffusionPipeline + >>> from diffusers.utils import load_image + >>> import torch + + >>> blip_diffusion_pipe = BlipDiffusionPipeline.from_pretrained( + ... "Salesforce/blipdiffusion", torch_dtype=torch.float16 + ... ).to("cuda") + + + >>> cond_subject = "dog" + >>> tgt_subject = "dog" + >>> text_prompt_input = "swimming underwater" + + >>> cond_image = load_image( + ... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/dog.jpg" + ... ) + >>> guidance_scale = 7.5 + >>> num_inference_steps = 25 + >>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate" + + + >>> output = blip_diffusion_pipe( + ... text_prompt_input, + ... cond_image, + ... cond_subject, + ... tgt_subject, + ... guidance_scale=guidance_scale, + ... num_inference_steps=num_inference_steps, + ... neg_prompt=negative_prompt, + ... height=512, + ... width=512, + ... ).images + >>> output[0].save("image.png") + ``` +""" + + +class BlipDiffusionPipeline(DeprecatedPipelineMixin, DiffusionPipeline): + """ + Pipeline for Zero-Shot Subject Driven Generation using Blip Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + tokenizer ([`CLIPTokenizer`]): + Tokenizer for the text encoder + text_encoder ([`ContextCLIPTextModel`]): + Text encoder to encode the text prompt + vae ([`AutoencoderKL`]): + VAE model to map the latents to the image + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + scheduler ([`PNDMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + qformer ([`Blip2QFormerModel`]): + QFormer model to get multi-modal embeddings from the text and image. + image_processor ([`BlipImageProcessor`]): + Image Processor to preprocess and postprocess the image. + ctx_begin_pos (int, `optional`, defaults to 2): + Position of the context token in the text encoder. + """ + + _last_supported_version = "0.33.1" + model_cpu_offload_seq = "qformer->text_encoder->unet->vae" + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: ContextCLIPTextModel, + vae: AutoencoderKL, + unet: UNet2DConditionModel, + scheduler: PNDMScheduler, + qformer: Blip2QFormerModel, + image_processor: BlipImageProcessor, + ctx_begin_pos: int = 2, + mean: list[float] = None, + std: list[float] = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + unet=unet, + scheduler=scheduler, + qformer=qformer, + image_processor=image_processor, + ) + self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std) + + def get_query_embeddings(self, input_image, src_subject): + return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False) + + # from the original Blip Diffusion code, specifies the target subject and augments the prompt by repeating it + def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20): + rv = [] + for prompt, tgt_subject in zip(prompts, tgt_subjects): + prompt = f"a {tgt_subject} {prompt.strip()}" + # a trick to amplify the prompt + rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps))) + + return rv + + # Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def encode_prompt(self, query_embeds, prompt, device=None): + device = device or self._execution_device + + # embeddings for prompt, with query_embeds as context + max_len = self.text_encoder.text_model.config.max_position_embeddings + max_len -= self.qformer.config.num_query_tokens + + tokenized_prompt = self.tokenizer( + prompt, + padding="max_length", + truncation=True, + max_length=max_len, + return_tensors="pt", + ).to(device) + + batch_size = query_embeds.shape[0] + ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size + + text_embeddings = self.text_encoder( + input_ids=tokenized_prompt.input_ids, + ctx_embeddings=query_embeds, + ctx_begin_pos=ctx_begin_pos, + )[0] + + return text_embeddings + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: list[str], + reference_image: PIL.Image.Image, + source_subject_category: list[str], + target_subject_category: list[str], + latents: torch.Tensor | None = None, + guidance_scale: float = 7.5, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + generator: torch.Generator | list[torch.Generator] | None = None, + neg_prompt: str | None = "", + prompt_strength: float = 1.0, + prompt_reps: int = 20, + output_type: str | None = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`list[str]`): + The prompt or prompts to guide the image generation. + reference_image (`PIL.Image.Image`): + The reference image to condition the generation on. + source_subject_category (`list[str]`): + The source subject category. + target_subject_category (`list[str]`): + The target subject category. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by random sampling. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + height (`int`, *optional*, defaults to 512): + The height of the generated image. + width (`int`, *optional*, defaults to 512): + The width of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + neg_prompt (`str`, *optional*, defaults to ""): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_strength (`float`, *optional*, defaults to 1.0): + The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps + to amplify the prompt. + prompt_reps (`int`, *optional*, defaults to 20): + The number of times the prompt is repeated along with prompt_strength to amplify the prompt. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + device = self._execution_device + + reference_image = self.image_processor.preprocess( + reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt" + )["pixel_values"] + reference_image = reference_image.to(device) + + if isinstance(prompt, str): + prompt = [prompt] + if isinstance(source_subject_category, str): + source_subject_category = [source_subject_category] + if isinstance(target_subject_category, str): + target_subject_category = [target_subject_category] + + batch_size = len(prompt) + + prompt = self._build_prompt( + prompts=prompt, + tgt_subjects=target_subject_category, + prompt_strength=prompt_strength, + prompt_reps=prompt_reps, + ) + query_embeds = self.get_query_embeddings(reference_image, source_subject_category) + text_embeddings = self.encode_prompt(query_embeds, prompt, device) + do_classifier_free_guidance = guidance_scale > 1.0 + if do_classifier_free_guidance: + max_length = self.text_encoder.text_model.config.max_position_embeddings + + uncond_input = self.tokenizer( + [neg_prompt] * batch_size, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + uncond_embeddings = self.text_encoder( + input_ids=uncond_input.input_ids.to(device), + ctx_embeddings=None, + )[0] + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1) + latents = self.prepare_latents( + batch_size=batch_size, + num_channels=self.unet.config.in_channels, + height=height // scale_down_factor, + width=width // scale_down_factor, + generator=generator, + latents=latents, + dtype=self.unet.dtype, + device=device, + ) + # set timesteps + extra_set_kwargs = {} + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + do_classifier_free_guidance = guidance_scale > 1.0 + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + noise_pred = self.unet( + latent_model_input, + timestep=t, + encoder_hidden_states=text_embeddings, + down_block_additional_residuals=None, + mid_block_additional_residual=None, + )["sample"] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step( + noise_pred, + t, + latents, + )["prev_sample"] + + if XLA_AVAILABLE: + xm.mark_step() + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/bria_fibo/__init__.py b/diffusers/pipelines/bria_fibo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8dd77270902cfa436b4d5ccf130349c0c963b924 --- /dev/null +++ b/diffusers/pipelines/bria_fibo/__init__.py @@ -0,0 +1,51 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_bria_fibo"] = ["BriaFiboPipeline"] + _import_structure["pipeline_bria_fibo_edit"] = ["BriaFiboEditPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_bria_fibo import BriaFiboPipeline + from .pipeline_bria_fibo_edit import BriaFiboEditPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py b/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py new file mode 100644 index 0000000000000000000000000000000000000000..1f178066b17d9836cb5755f64ef35e2749ecc692 --- /dev/null +++ b/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py @@ -0,0 +1,838 @@ +# Copyright (c) Bria.ai. All rights reserved. +# +# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0). +# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/ +# +# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit, +# indicate if changes were made, and do not use the material for commercial purposes. +# +# See the license for further details. + +from typing import Any, Callable + +import numpy as np +import torch +from transformers import AutoTokenizer +from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM + +from ...image_processor import VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin +from ...models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan +from ...models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel +from ...pipelines.bria_fibo.pipeline_output import BriaFiboPipelineOutput +from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Example: + ```python + import torch + from diffusers import BriaFiboPipeline + from diffusers.modular_pipelines import ModularPipeline + + torch.set_grad_enabled(False) + vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True) + + pipe = BriaFiboPipeline.from_pretrained( + "briaai/FIBO", + trust_remote_code=True, + torch_dtype=torch.bfloat16, + ) + pipe.enable_model_cpu_offload() + + with torch.inference_mode(): + # 1. Create a prompt to generate an initial image + output = vlm_pipe(prompt="a beautiful dog") + json_prompt_generate = output.values["json_prompt"] + + # Generate the image from the structured json prompt + results_generate = pipe(prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=5) + results_generate.images[0].save("image_generate.png") + ``` +""" + + +class BriaFiboPipeline(DiffusionPipeline, FluxLoraLoaderMixin): + r""" + Args: + transformer (`BriaFiboTransformer2DModel`): + The transformer model for 2D diffusion modeling. + scheduler (`FlowMatchEulerDiscreteScheduler` or `KarrasDiffusionSchedulers`): + Scheduler to be used with `transformer` to denoise the encoded latents. + vae (`AutoencoderKLWan`): + Variational Auto-Encoder for encoding and decoding images to and from latent representations. + text_encoder (`SmolLM3ForCausalLM`): + Text encoder for processing input prompts. + tokenizer (`AutoTokenizer`): + Tokenizer used for processing the input text prompts for the text_encoder. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + transformer: BriaFiboTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler | KarrasDiffusionSchedulers, + vae: AutoencoderKLWan, + text_encoder: SmolLM3ForCausalLM, + tokenizer: AutoTokenizer, + ): + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor = 16 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.default_sample_size = 64 + + def get_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + max_sequence_length: int = 2048, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + if not prompt: + raise ValueError("`prompt` must be a non-empty string or list of strings.") + + batch_size = len(prompt) + bot_token_id = 128000 + + text_encoder_device = device if device is not None else torch.device("cpu") + if not isinstance(text_encoder_device, torch.device): + text_encoder_device = torch.device(text_encoder_device) + + if all(p == "" for p in prompt): + input_ids = torch.full((batch_size, 1), bot_token_id, dtype=torch.long, device=text_encoder_device) + attention_mask = torch.ones_like(input_ids) + else: + tokenized = self.tokenizer( + prompt, + padding="longest", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + input_ids = tokenized.input_ids.to(text_encoder_device) + attention_mask = tokenized.attention_mask.to(text_encoder_device) + + if any(p == "" for p in prompt): + empty_rows = torch.tensor([p == "" for p in prompt], dtype=torch.bool, device=text_encoder_device) + input_ids[empty_rows] = bot_token_id + attention_mask[empty_rows] = 1 + + encoder_outputs = self.text_encoder( + input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_outputs.hidden_states + + prompt_embeds = torch.cat([hidden_states[-1], hidden_states[-2]], dim=-1) + prompt_embeds = prompt_embeds.to(device=device, dtype=dtype) + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + hidden_states = tuple( + layer.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) for layer in hidden_states + ) + attention_mask = attention_mask.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) + + return prompt_embeds, hidden_states, attention_mask + + @staticmethod + def pad_embedding(prompt_embeds, max_tokens, attention_mask=None): + # Pad embeddings to `max_tokens` while preserving the mask of real tokens. + batch_size, seq_len, dim = prompt_embeds.shape + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), dtype=prompt_embeds.dtype, device=prompt_embeds.device) + else: + attention_mask = attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype) + + if max_tokens < seq_len: + raise ValueError("`max_tokens` must be greater or equal to the current sequence length.") + + if max_tokens > seq_len: + pad_length = max_tokens - seq_len + padding = torch.zeros( + (batch_size, pad_length, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + prompt_embeds = torch.cat([prompt_embeds, padding], dim=1) + + mask_padding = torch.zeros( + (batch_size, pad_length), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + attention_mask = torch.cat([attention_mask, mask_padding], dim=1) + + return prompt_embeds, attention_mask + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + guidance_scale: float = 5, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 3000, + lora_scale: float | None = None, + ): + r""" + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + guidance_scale (`float`): + Guidance scale for classifier free guidance. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + prompt_attention_mask = None + negative_prompt_attention_mask = None + if prompt_embeds is None: + prompt_embeds, prompt_layers, prompt_attention_mask = self.get_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype) + prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers] + + if guidance_scale > 1: + if isinstance(negative_prompt, list) and negative_prompt[0] is None: + negative_prompt = "" + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_layers, negative_prompt_attention_mask = self.get_prompt_embeds( + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.transformer.dtype) + negative_prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in negative_prompt_layers] + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + # Pad to longest + if prompt_attention_mask is not None: + prompt_attention_mask = prompt_attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype) + + if negative_prompt_embeds is not None: + if negative_prompt_attention_mask is not None: + negative_prompt_attention_mask = negative_prompt_attention_mask.to( + device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype + ) + max_tokens = max(negative_prompt_embeds.shape[1], prompt_embeds.shape[1]) + + prompt_embeds, prompt_attention_mask = self.pad_embedding( + prompt_embeds, max_tokens, attention_mask=prompt_attention_mask + ) + prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in prompt_layers] + + negative_prompt_embeds, negative_prompt_attention_mask = self.pad_embedding( + negative_prompt_embeds, max_tokens, attention_mask=negative_prompt_attention_mask + ) + negative_prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in negative_prompt_layers] + else: + max_tokens = prompt_embeds.shape[1] + prompt_embeds, prompt_attention_mask = self.pad_embedding( + prompt_embeds, max_tokens, attention_mask=prompt_attention_mask + ) + negative_prompt_layers = None + + dtype = self.text_encoder.dtype + text_ids = torch.zeros(prompt_embeds.shape[0], max_tokens, 3).to(device=device, dtype=dtype) + + return ( + prompt_embeds, + negative_prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_layers, + negative_prompt_layers, + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @staticmethod + # Based on diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _unpack_latents_no_patch(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels) + latents = latents.permute(0, 3, 1, 2) + + return latents + + @staticmethod + def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width): + latents = latents.permute(0, 2, 3, 1) + latents = latents.reshape(batch_size, height * width, num_channels_latents) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + do_patching=False, + ): + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if do_patching: + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + else: + latents = self._pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + return latents, latent_image_ids + + @staticmethod + def _prepare_attention_mask(attention_mask): + attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask) + + # convert to 0 - keep, -inf ignore + attention_matrix = torch.where( + attention_matrix == 1, 0.0, -torch.inf + ) # Apply -inf to ignored tokens for nulling softmax score + return attention_matrix + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 30, + timesteps: list[int] = None, + guidance_scale: float = 5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 3000, + do_patching=False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 3000): Maximum sequence length to use with the `prompt`. + do_patching (`bool`, *optional*, defaults to `False`): Whether to use patching. + Examples: + Returns: + [`~pipelines.flux.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.flux.BriaFiboPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_layers, + negative_prompt_layers, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + guidance_scale=guidance_scale, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + num_images_per_prompt=num_images_per_prompt, + lora_scale=lora_scale, + ) + prompt_batch_size = prompt_embeds.shape[0] + + if guidance_scale > 1: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_layers = [ + torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers)) + ] + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + total_num_layers_transformer = len(self.transformer.transformer_blocks) + len( + self.transformer.single_transformer_blocks + ) + if len(prompt_layers) >= total_num_layers_transformer: + # remove first layers + prompt_layers = prompt_layers[len(prompt_layers) - total_num_layers_transformer :] + else: + # duplicate last layer + prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers)) + + # 5. Prepare latent variables + + num_channels_latents = self.transformer.config.in_channels + if do_patching: + num_channels_latents = int(num_channels_latents / 4) + + latents, latent_image_ids = self.prepare_latents( + prompt_batch_size, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + do_patching, + ) + + latent_attention_mask = torch.ones( + [latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device + ) + if guidance_scale > 1: + latent_attention_mask = latent_attention_mask.repeat(2, 1) + + attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1) + attention_mask = self._prepare_attention_mask(attention_mask) # batch, seq => batch, seq, seq + attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype) # for head broadcasting + + if self._joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + self._joint_attention_kwargs["attention_mask"] = attention_mask + + # Adapt scheduler to dynamic shifting (resolution dependent) + + if do_patching: + seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2)) + else: + seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor) + + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + + mu = calculate_shift( + seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + + # Init sigmas and timesteps according to shift size + # This changes the scheduler in-place according to the dynamic scheduling + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps=num_inference_steps, + device=device, + timesteps=None, + sigmas=sigmas, + mu=mu, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # Support old different diffusers versions + if len(latent_image_ids.shape) == 3: + latent_image_ids = latent_image_ids[0] + + if len(text_ids.shape) == 3: + text_ids = text_ids[0] + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to( + device=latent_model_input.device, dtype=latent_model_input.dtype + ) + + # This is predicts "v" from flow-matching or eps from diffusion + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + text_encoder_layers=prompt_layers, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + txt_ids=text_ids, + img_ids=latent_image_ids, + )[0] + + # perform guidance + if guidance_scale > 1: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + if do_patching: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + else: + latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor) + + latents = latents.unsqueeze(dim=2) + latents_device = latents[0].device + latents_dtype = latents[0].dtype + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents_device, latents_dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents_device, latents_dtype + ) + latents_scaled = [latent / latents_std + latents_mean for latent in latents] + latents_scaled = torch.cat(latents_scaled, dim=0) + image = [] + for scaled_latent in latents_scaled: + curr_image = self.vae.decode(scaled_latent.unsqueeze(0), return_dict=False)[0] + curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type) + image.append(curr_image) + if len(image) == 1: + image = image[0] + else: + image = np.stack(image, axis=0) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return BriaFiboPipelineOutput(images=image) + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if max_sequence_length is not None and max_sequence_length > 3000: + raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}") diff --git a/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py b/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py new file mode 100644 index 0000000000000000000000000000000000000000..c2327bbce1c7820312652f3476dcda2f4acd9e0c --- /dev/null +++ b/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py @@ -0,0 +1,1133 @@ +# Copyright (c) Bria.ai. All rights reserved. +# +# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0). +# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/ +# +# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit, +# indicate if changes were made, and do not use the material for commercial purposes. +# +# See the license for further details. + +import json +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer +from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin +from ...models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan +from ...models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel +from ...pipelines.bria_fibo.pipeline_output import BriaFiboPipelineOutput +from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +PipelineMaskInput = Union[ + torch.FloatTensor, Image.Image, List[Image.Image], List[torch.FloatTensor], np.ndarray, List[np.ndarray] +] + +# TODO: Update example docstring +EXAMPLE_DOC_STRING = """ + Example: + ```python + import torch + from diffusers import BriaFiboEditPipeline + from diffusers.modular_pipelines import ModularPipeline + + torch.set_grad_enabled(False) + vlm_pipe = ModularPipelineBlocks.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True) + vlm_pipe = vlm_pipe.init_pipeline() + + pipe = BriaFiboEditPipeline.from_pretrained( + "briaai/fibo-edit", + torch_dtype=torch.bfloat16, + ) + pipe.to("cuda") + + output = vlm_pipe( + prompt="A hyper-detailed, ultra-fluffy owl sitting in the trees at night, looking directly at the camera with wide, adorable, expressive eyes. Its feathers are soft and voluminous, catching the cool moonlight with subtle silver highlights. The owl's gaze is curious and full of charm, giving it a whimsical, storybook-like personality." + ) + json_prompt_generate = json.loads(output.values["json_prompt"]) + + image = Image.open("image_generate.png") + + edit_prompt = "Make the owl to be a cat" + + json_prompt_generate["edit_instruction"] = edit_prompt + + results_generate = pipe( + prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=3.5, image=image, output_type="np" + ) + ``` +""" + +PREFERRED_RESOLUTION = { + 256 * 256: [(208, 304), (224, 288), (256, 256), (288, 224), (304, 208), (320, 192), (336, 192)], + 512 * 512: [ + (416, 624), + (432, 592), + (464, 560), + (512, 512), + (544, 480), + (576, 448), + (592, 432), + (608, 416), + (624, 416), + (640, 400), + (672, 384), + (704, 368), + ], + 1024 * 1024: [ + (832, 1248), + (880, 1184), + (912, 1136), + (1024, 1024), + (1136, 912), + (1184, 880), + (1216, 848), + (1248, 832), + (1248, 832), + (1264, 816), + (1296, 800), + (1360, 768), + ], +} + + +def is_valid_edit_json(json_input: str | dict): + """ + Check if the input is a valid JSON string or dict with an "edit_instruction" key. + + Args: + json_input (`str` or `dict`): + The JSON string or dict to check. + + Returns: + `bool`: True if the input is a valid JSON string or dict with an "edit_instruction" key, False otherwise. + """ + try: + if isinstance(json_input, str) and "edit_instruction" in json_input: + json.loads(json_input) + return True + elif isinstance(json_input, dict) and "edit_instruction" in json_input: + return True + else: + return False + except json.JSONDecodeError: + return False + + +def is_valid_mask(mask: PipelineMaskInput): + """ + Check if the mask is a valid mask. + """ + if isinstance(mask, torch.Tensor): + return True + elif isinstance(mask, Image.Image): + return True + elif isinstance(mask, list): + return all(isinstance(m, (torch.Tensor, Image.Image, np.ndarray)) for m in mask) + elif isinstance(mask, np.ndarray): + return mask.ndim in [2, 3] and mask.min() >= 0 and mask.max() <= 1 + else: + return False + + +def get_mask_size(mask: PipelineMaskInput): + """ + Get the size of the mask. + """ + if isinstance(mask, torch.Tensor): + return mask.shape[-2:] + elif isinstance(mask, Image.Image): + return mask.size[::-1] # (height, width) + elif isinstance(mask, list): + return [get_mask_size(m) for m in mask] + elif isinstance(mask, np.ndarray): + return mask.shape[-2:] + else: + return None + + +def get_image_size(image: PipelineImageInput): + """ + Get the size of the image. + """ + if isinstance(image, torch.Tensor): + return image.shape[-2:] + elif isinstance(image, Image.Image): + return image.size[::-1] # (height, width) + elif isinstance(image, list): + return [get_image_size(i) for i in image] + else: + return None + + +def paste_mask_on_image(mask: PipelineMaskInput, image: PipelineImageInput): + """convert mask and image to PIL Images and paste the mask on the image""" + if isinstance(mask, torch.Tensor): + if mask.ndim == 3 and mask.shape[0] == 1: + mask = mask.squeeze(0) + mask = Image.fromarray((mask.cpu().numpy() * 255).astype(np.uint8)) + elif isinstance(mask, Image.Image): + pass + elif isinstance(mask, list): + mask = mask[0] + if isinstance(mask, torch.Tensor): + if mask.ndim == 3 and mask.shape[0] == 1: + mask = mask.squeeze(0) + mask = Image.fromarray((mask.cpu().numpy() * 255).astype(np.uint8)) + elif isinstance(mask, np.ndarray): + mask = Image.fromarray((mask * 255).astype(np.uint8)) + elif isinstance(mask, np.ndarray): + mask = Image.fromarray((mask * 255).astype(np.uint8)) + + if isinstance(image, torch.Tensor): + if image.ndim == 3: + image = image.permute(1, 2, 0) + image = Image.fromarray((image.cpu().numpy() * 255).astype(np.uint8)) + elif isinstance(image, Image.Image): + pass + elif isinstance(image, list): + image = image[0] + if isinstance(image, torch.Tensor): + if image.ndim == 3: + image = image.permute(1, 2, 0) + image = Image.fromarray((image.cpu().numpy() * 255).astype(np.uint8)) + elif isinstance(image, np.ndarray): + image = Image.fromarray((image * 255).astype(np.uint8)) + elif isinstance(image, np.ndarray): + image = Image.fromarray((image * 255).astype(np.uint8)) + + mask = mask.convert("L") + image = image.convert("RGB") + gray_color = (128, 128, 128) + gray_img = Image.new("RGB", image.size, gray_color) + image = Image.composite(gray_img, image, mask) + return image + + +class BriaFiboEditPipeline(DiffusionPipeline, FluxLoraLoaderMixin): + r""" + Args: + transformer (`BriaFiboTransformer2DModel`): + The transformer model for 2D diffusion modeling. + scheduler (`FlowMatchEulerDiscreteScheduler` or `KarrasDiffusionSchedulers`): + Scheduler to be used with `transformer` to denoise the encoded latents. + vae (`AutoencoderKLWan`): + Variational Auto-Encoder for encoding and decoding images to and from latent representations. + text_encoder (`SmolLM3ForCausalLM`): + Text encoder for processing input prompts. + tokenizer (`AutoTokenizer`): + Tokenizer used for processing the input text prompts for the text_encoder. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + transformer: BriaFiboTransformer2DModel, + scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers], + vae: AutoencoderKLWan, + text_encoder: SmolLM3ForCausalLM, + tokenizer: AutoTokenizer, + ): + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor = 16 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # * 2) + self.default_sample_size = 32 # 64 + + def get_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + max_sequence_length: int = 2048, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + if not prompt: + raise ValueError("`prompt` must be a non-empty string or list of strings.") + + batch_size = len(prompt) + bot_token_id = 128000 + + text_encoder_device = device if device is not None else torch.device("cpu") + if not isinstance(text_encoder_device, torch.device): + text_encoder_device = torch.device(text_encoder_device) + + if all(p == "" for p in prompt): + input_ids = torch.full((batch_size, 1), bot_token_id, dtype=torch.long, device=text_encoder_device) + attention_mask = torch.ones_like(input_ids) + else: + tokenized = self.tokenizer( + prompt, + padding="longest", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + input_ids = tokenized.input_ids.to(text_encoder_device) + attention_mask = tokenized.attention_mask.to(text_encoder_device) + + if any(p == "" for p in prompt): + empty_rows = torch.tensor([p == "" for p in prompt], dtype=torch.bool, device=text_encoder_device) + input_ids[empty_rows] = bot_token_id + attention_mask[empty_rows] = 1 + + encoder_outputs = self.text_encoder( + input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_outputs.hidden_states + + prompt_embeds = torch.cat([hidden_states[-1], hidden_states[-2]], dim=-1) + prompt_embeds = prompt_embeds.to(device=device, dtype=dtype) + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + hidden_states = tuple( + layer.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) for layer in hidden_states + ) + attention_mask = attention_mask.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) + + return prompt_embeds, hidden_states, attention_mask + + @staticmethod + def pad_embedding(prompt_embeds, max_tokens, attention_mask=None): + # Pad embeddings to `max_tokens` while preserving the mask of real tokens. + batch_size, seq_len, dim = prompt_embeds.shape + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), dtype=prompt_embeds.dtype, device=prompt_embeds.device) + else: + attention_mask = attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype) + + if max_tokens < seq_len: + raise ValueError("`max_tokens` must be greater or equal to the current sequence length.") + + if max_tokens > seq_len: + pad_length = max_tokens - seq_len + padding = torch.zeros( + (batch_size, pad_length, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + prompt_embeds = torch.cat([prompt_embeds, padding], dim=1) + + mask_padding = torch.zeros( + (batch_size, pad_length), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + attention_mask = torch.cat([attention_mask, mask_padding], dim=1) + + return prompt_embeds, attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + guidance_scale: float = 5, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 3000, + lora_scale: bool | None = None, + ): + r""" + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + guidance_scale (`float`): + Guidance scale for classifier free guidance. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + prompt_attention_mask = None + negative_prompt_attention_mask = None + if prompt_embeds is None: + prompt_embeds, prompt_layers, prompt_attention_mask = self.get_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype) + prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers] + + if guidance_scale > 1: + if isinstance(negative_prompt, list) and negative_prompt[0] is None: + negative_prompt = "" + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_layers, negative_prompt_attention_mask = self.get_prompt_embeds( + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.transformer.dtype) + negative_prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in negative_prompt_layers] + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + # Pad to longest + if prompt_attention_mask is not None: + prompt_attention_mask = prompt_attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype) + + if negative_prompt_embeds is not None: + if negative_prompt_attention_mask is not None: + negative_prompt_attention_mask = negative_prompt_attention_mask.to( + device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype + ) + max_tokens = max(negative_prompt_embeds.shape[1], prompt_embeds.shape[1]) + + prompt_embeds, prompt_attention_mask = self.pad_embedding( + prompt_embeds, max_tokens, attention_mask=prompt_attention_mask + ) + prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in prompt_layers] + + negative_prompt_embeds, negative_prompt_attention_mask = self.pad_embedding( + negative_prompt_embeds, max_tokens, attention_mask=negative_prompt_attention_mask + ) + negative_prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in negative_prompt_layers] + else: + max_tokens = prompt_embeds.shape[1] + prompt_embeds, prompt_attention_mask = self.pad_embedding( + prompt_embeds, max_tokens, attention_mask=prompt_attention_mask + ) + negative_prompt_layers = None + + dtype = self.text_encoder.dtype + text_ids = torch.zeros(prompt_embeds.shape[0], max_tokens, 3).to(device=device, dtype=dtype) + + return ( + prompt_embeds, + negative_prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_layers, + negative_prompt_layers, + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @staticmethod + # Based on diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _unpack_latents_no_patch(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels) + latents = latents.permute(0, 3, 1, 2) + + return latents + + @staticmethod + def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width): + latents = latents.permute(0, 2, 3, 1) + latents = latents.reshape(batch_size, height * width, num_channels_latents) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + do_patching=False, + ): + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if do_patching: + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + else: + latents = self._pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + return latents, latent_image_ids + + @staticmethod + def _prepare_attention_mask(attention_mask): + attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask) + + # convert to 0 - keep, -inf ignore + attention_matrix = torch.where( + attention_matrix == 1, 0.0, -torch.inf + ) # Apply -inf to ignored tokens for nulling softmax score + return attention_matrix + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Optional[PipelineImageInput] = None, + mask: Optional[PipelineMaskInput] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 30, + timesteps: List[int] = None, + seed: int | None = None, + guidance_scale: float = 5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 3000, + do_patching=False, + _auto_resize: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image` or `torch.FloatTensor`, *optional*): + The image to guide the image generation. If not defined, the pipeline will generate an image from + scratch. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + seed (`int`, *optional*): + A seed used to make generation deterministic. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 3000): Maximum sequence length to use with the `prompt`. + do_patching (`bool`, *optional*, defaults to `False`): Whether to use patching. + Examples: + Returns: + [`~pipelines.flux.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.flux.BriaFiboPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + if height is None or width is None: + if image is not None: + image_height, image_width = self.image_processor.get_default_height_width(image) + if _auto_resize: + image_width, image_height = min( + PREFERRED_RESOLUTION[1024 * 1024], + key=lambda size: abs(size[0] / size[1] - image_width / image_height), + ) + width, height = image_width, image_height + else: + raise ValueError("You must provide either an image or both height and width.") + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + seed=seed, + image=image, + mask=mask, + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + if mask is not None and image is not None: + image = paste_mask_on_image(mask, image) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + + if prompt is not None and is_valid_edit_json(prompt): + prompt = json.dumps(prompt) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + if generator is None and seed is not None: + generator = torch.Generator(device=device).manual_seed(seed) + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_layers, + negative_prompt_layers, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + guidance_scale=guidance_scale, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + num_images_per_prompt=num_images_per_prompt, + lora_scale=lora_scale, + ) + prompt_batch_size = prompt_embeds.shape[0] + + if guidance_scale > 1: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_layers = [ + torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers)) + ] + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + total_num_layers_transformer = len(self.transformer.transformer_blocks) + len( + self.transformer.single_transformer_blocks + ) + if len(prompt_layers) >= total_num_layers_transformer: + # remove first layers + prompt_layers = prompt_layers[len(prompt_layers) - total_num_layers_transformer :] + else: + # duplicate last layer + prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers)) + + # Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + image = self.image_processor.resize(image, height, width) + image = self.image_processor.preprocess(image, height, width) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + if do_patching: + num_channels_latents = int(num_channels_latents / 4) + + latents, latent_image_ids = self.prepare_latents( + prompt_batch_size, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + do_patching, + ) + + if image is not None: + image_latents, image_ids = self.prepare_image_latents( + image=image, + batch_size=batch_size * num_images_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + ) + latent_image_ids = torch.cat([latent_image_ids, image_ids], dim=0) # dim 0 is sequence dimension + else: + image_latents = None + + latent_attention_mask = torch.ones( + [latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device + ) + if guidance_scale > 1: + latent_attention_mask = latent_attention_mask.repeat(2, 1) + + if image_latents is None: + attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1) + else: + image_latent_attention_mask = torch.ones( + [image_latents.shape[0], image_latents.shape[1]], + dtype=image_latents.dtype, + device=image_latents.device, + ) + if guidance_scale > 1: + image_latent_attention_mask = image_latent_attention_mask.repeat(2, 1) + attention_mask = torch.cat( + [prompt_attention_mask, latent_attention_mask, image_latent_attention_mask], dim=1 + ) + + attention_mask = self.create_attention_matrix(attention_mask) # batch, seq => batch, seq, seq + attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype) # for head broadcasting + + if self._joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + self._joint_attention_kwargs["attention_mask"] = attention_mask + + # Adapt scheduler to dynamic shifting (resolution dependent) + + if do_patching: + seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2)) + else: + seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor) + + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + + mu = calculate_shift( + seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + + # Init sigmas and timesteps according to shift size + # This changes the scheduler in-place according to the dynamic scheduling + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps=num_inference_steps, + device=device, + timesteps=None, + sigmas=sigmas, + mu=mu, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # Support old different diffusers versions + if len(latent_image_ids.shape) == 3: + latent_image_ids = latent_image_ids[0] + + if len(text_ids.shape) == 3: + text_ids = text_ids[0] + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = latents + + if image_latents is not None: + latent_model_input = torch.cat([latent_model_input, image_latents], dim=1) + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latent_model_input] * 2) if guidance_scale > 1 else latent_model_input + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to( + device=latent_model_input.device, dtype=latent_model_input.dtype + ) + + # This is predicts "v" from flow-matching or eps from diffusion + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + text_encoder_layers=prompt_layers, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + txt_ids=text_ids, + img_ids=latent_image_ids, + )[0] + + # perform guidance + if guidance_scale > 1: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred[:, : latents.shape[1], ...], t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + if do_patching: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + else: + latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor) + + latents = latents.unsqueeze(dim=2) + latents_device = latents[0].device + latents_dtype = latents[0].dtype + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents_device, latents_dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents_device, latents_dtype + ) + latents_scaled = [latent / latents_std + latents_mean for latent in latents] + latents_scaled = torch.cat(latents_scaled, dim=0) + image = [] + for scaled_latent in latents_scaled: + curr_image = self.vae.decode(scaled_latent.unsqueeze(0), return_dict=False)[0] + curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type) + image.append(curr_image) + if len(image) == 1: + image = image[0] + else: + image = np.stack(image, axis=0) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return BriaFiboPipelineOutput(images=image) + + def prepare_image_latents( + self, + image: torch.Tensor, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator | list[torch.Generator] | None = None, + ): + image = image.to(device=device, dtype=dtype) + + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + # scaling + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device, dtype + ) + + image_latents_cthw = self.vae.encode(image.unsqueeze(2)).latent_dist.mean + latents_scaled = [(latent - latents_mean) * latents_std for latent in image_latents_cthw] + image_latents_cthw = torch.concat(latents_scaled, dim=0) + image_latents_bchw = image_latents_cthw[:, :, 0, :, :] + + image_latent_height, image_latent_width = image_latents_bchw.shape[2:] + image_latents_bsd = self._pack_latents_no_patch( + latents=image_latents_bchw, + batch_size=batch_size, + num_channels_latents=num_channels_latents, + height=image_latent_height, + width=image_latent_width, + ) + # breakpoint() + image_ids = self._prepare_latent_image_ids( + batch_size=batch_size, height=image_latent_height, width=image_latent_width, device=device, dtype=dtype + ) + # image ids are the same as latent ids with the first dimension set to 1 instead of 0 + image_ids[..., 0] = 1 + return image_latents_bsd, image_ids + + def check_inputs( + self, + prompt, + seed, + image, + mask, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if seed is not None and not isinstance(seed, int): + raise ValueError("Seed must be an integer") + if image is not None and not isinstance(image, (torch.Tensor, Image.Image, list)): + raise ValueError("Image must be a valid image") + if image is None and mask is not None: + raise ValueError("If mask is provided, image must also be provided") + + if mask is not None and not is_valid_mask(mask): + raise ValueError("Mask must be a valid mask") + + if mask is not None and image is not None and not (get_mask_size(mask) == get_image_size(image)): + raise ValueError("Mask and image must have the same size") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and not is_valid_edit_json(prompt): + raise ValueError(f"`prompt` has to be a valid JSON string or dict but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if max_sequence_length is not None and max_sequence_length > 3000: + raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}") + + def create_attention_matrix(self, attention_mask): + attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask) + + # convert to 0 - keep, -inf ignore + attention_matrix = torch.where( + attention_matrix == 1, 0.0, -torch.inf + ) # Apply -inf to ignored tokens for nulling softmax score + return attention_matrix diff --git a/diffusers/pipelines/bria_fibo/pipeline_output.py b/diffusers/pipelines/bria_fibo/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..0c131db29d9f08d6767a4c11ac78055f020916ba --- /dev/null +++ b/diffusers/pipelines/bria_fibo/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class BriaFiboPipelineOutput(BaseOutput): + """ + Output class for BriaFibo pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image, np.ndarray] diff --git a/diffusers/pipelines/chronoedit/__init__.py b/diffusers/pipelines/chronoedit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cffe4660977f28c13fddb0f8948b419d43efa79a --- /dev/null +++ b/diffusers/pipelines/chronoedit/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_chronoedit"] = ["ChronoEditPipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_chronoedit import ChronoEditPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/chronoedit/pipeline_chronoedit.py b/diffusers/pipelines/chronoedit/pipeline_chronoedit.py new file mode 100644 index 0000000000000000000000000000000000000000..1e0cc0ea5c2aa60167b9d39c72e78d95ae1fe1a8 --- /dev/null +++ b/diffusers/pipelines/chronoedit/pipeline_chronoedit.py @@ -0,0 +1,750 @@ +# Copyright 2025 The ChronoEdit Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +from typing import Any, Callable + +import PIL +import regex as re +import torch +from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import WanLoraLoaderMixin +from ...models import AutoencoderKLWan, ChronoEditTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import ChronoEditPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> import numpy as np + >>> from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline + >>> from diffusers.utils import export_to_video, load_image + >>> from transformers import CLIPVisionModel + + >>> # Available models: nvidia/ChronoEdit-14B-Diffusers + >>> model_id = "nvidia/ChronoEdit-14B-Diffusers" + >>> image_encoder = CLIPVisionModel.from_pretrained( + ... model_id, subfolder="image_encoder", torch_dtype=torch.float32 + ... ) + >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + >>> transformer = ChronoEditTransformer3DModel.from_pretrained( + ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = ChronoEditPipeline.from_pretrained( + ... model_id, vae=vae, image_encoder=image_encoder, transformer=transformer, torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> image = load_image("https://huggingface.co/spaces/nvidia/ChronoEdit/resolve/main/examples/3.png") + >>> max_area = 720 * 1280 + >>> aspect_ratio = image.height / image.width + >>> mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] + >>> height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + >>> width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + >>> image = image.resize((width, height)) + >>> prompt = ( + ... "The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cup’s liquid, with gentle steam rising around it to blend with the cozy atmosphere. " + ... "The mouse’s pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacup’s floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood." + ... ) + + >>> output = pipe( + ... image=image, + ... prompt=prompt, + ... height=height, + ... width=width, + ... num_frames=5, + ... guidance_scale=5.0, + ... enable_temporal_reasoning=False, + ... num_temporal_reasoning_steps=0, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=16) + ``` +""" + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class ChronoEditPipeline(DiffusionPipeline, WanLoraLoaderMixin): + r""" + Pipeline for image-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + image_encoder ([`CLIPVisionModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically + the + [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large) + variant. + transformer ([`WanTransformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + image_encoder: CLIPVisionModel, + image_processor: CLIPImageProcessor, + transformer: ChronoEditTransformer3DModel, + vae: AutoencoderKLWan, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + image_encoder=image_encoder, + transformer=transformer, + scheduler=scheduler, + image_processor=image_processor, + ) + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.image_processor = image_processor + + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_image + def encode_image( + self, + image: PipelineImageInput, + device: torch.device | None = None, + ): + device = device or self._execution_device + image = self.image_processor(images=image, return_tensors="pt").to(device) + image_embeds = self.image_encoder(**image, output_hidden_states=True) + return image_embeds.hidden_states[-2] + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + # modified from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.check_inputs + def check_inputs( + self, + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if image is not None and image_embeds is not None: + raise ValueError( + f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to" + " only forward one of the two." + ) + if image is None and image_embeds is None: + raise ValueError( + "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined." + ) + if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + # modified from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.prepare_latents + def prepare_latents( + self, + image: PipelineImageInput, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + image = image.unsqueeze(2) # [batch_size, channels, 1, height, width] + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 + ) + video_condition = video_condition.to(device=device, dtype=self.vae.dtype) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + + if isinstance(generator, list): + latent_condition = [ + retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator + ] + latent_condition = torch.cat(latent_condition) + else: + latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") + latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + + latent_condition = latent_condition.to(dtype) + latent_condition = (latent_condition - latents_mean) * latents_std + + mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) + mask_lat_size[:, :, list(range(1, num_frames))] = 0 + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(latent_condition.device) + + return latents, torch.concat([mask_lat_size, latent_condition], dim=1) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + image_embeds: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int, None], PipelineCallback | MultiPipelineCallbacks] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + enable_temporal_reasoning: bool = False, + num_temporal_reasoning_steps: int = 0, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, defaults to `480`): + The height of the generated video. + width (`int`, defaults to `832`): + The width of the generated video. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `negative_prompt` input argument. + image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided, + image embeddings are generated from the `image` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`ChronoEditPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + enable_temporal_reasoning (`bool`, *optional*, defaults to `False`): + Whether to enable temporal reasoning. + num_temporal_reasoning_steps (`int`, *optional*, defaults to `0`): + The number of steps to enable temporal reasoning. + + Examples: + + Returns: + [`~ChronoEditPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`ChronoEditPipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds, + negative_prompt_embeds, + image_embeds, + callback_on_step_end_tensor_inputs, + ) + + num_frames = 5 if not enable_temporal_reasoning else num_frames + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + # Encode image embedding + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + if image_embeds is None: + image_embeds = self.encode_image(image, device) + image_embeds = image_embeds.repeat(batch_size, 1, 1) + image_embeds = image_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.vae.config.z_dim + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + latents, condition = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + if enable_temporal_reasoning and i == num_temporal_reasoning_steps: + latents = latents[:, :, [0, -1]] + condition = condition[:, :, [0, -1]] + + for j in range(len(self.scheduler.model_outputs)): + if self.scheduler.model_outputs[j] is not None: + if latents.shape[-3] != self.scheduler.model_outputs[j].shape[-3]: + self.scheduler.model_outputs[j] = self.scheduler.model_outputs[j][:, :, [0, -1]] + if self.scheduler.last_sample is not None: + self.scheduler.last_sample = self.scheduler.last_sample[:, :, [0, -1]] + + self._current_timestep = t + latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + if enable_temporal_reasoning and latents.shape[2] > 2: + video_edit = self.vae.decode(latents[:, :, [0, -1]], return_dict=False)[0] + video_reason = self.vae.decode(latents[:, :, :-1], return_dict=False)[0] + video = torch.cat([video_reason, video_edit[:, :, 1:]], dim=2) + else: + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return ChronoEditPipelineOutput(frames=video) diff --git a/diffusers/pipelines/chronoedit/pipeline_output.py b/diffusers/pipelines/chronoedit/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..6247ce9f3a0c57e5ee3e185d4c31211261e4379c --- /dev/null +++ b/diffusers/pipelines/chronoedit/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class ChronoEditPipelineOutput(BaseOutput): + r""" + Output class for ChronoEdit pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/diffusers/pipelines/cogview3/__init__.py b/diffusers/pipelines/cogview3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..50895251ba0b67dc2848b99654cbdd08312833f5 --- /dev/null +++ b/diffusers/pipelines/cogview3/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["CogView3PlusPipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_cogview3plus"] = ["CogView3PlusPipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_cogview3plus import CogView3PlusPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/cogview3/pipeline_cogview3plus.py b/diffusers/pipelines/cogview3/pipeline_cogview3plus.py new file mode 100644 index 0000000000000000000000000000000000000000..8880e3a0d1e2da02956d64853dbaca89429ced01 --- /dev/null +++ b/diffusers/pipelines/cogview3/pipeline_cogview3plus.py @@ -0,0 +1,686 @@ +# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, CogView3PlusTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import CogView3PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogView3PlusPipeline + + >>> pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3-Plus-3B", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + >>> image.save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class CogView3PlusPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using CogView3Plus. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogView3Plus uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogView3PlusTransformer2DModel`]): + A text conditioned `CogView3PlusTransformer2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: CogView3PlusTransformer2DModel, + scheduler: CogVideoXDDIMScheduler | CogVideoXDPMScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds with num_videos_per_prompt->num_images_per_prompt + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 224, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of images that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + max_sequence_length (`int`, defaults to `224`): + Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt is None: + negative_prompt_embeds = prompt_embeds.new_zeros(prompt_embeds.shape) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] | None = None, + guidance_scale: float = 5.0, + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + original_size: tuple[int, int] | None = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + output_type: str = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 224, + ) -> CogView3PipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. If not provided, it is set to 1024. + width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. If not provided it is set to 1024. + num_inference_steps (`int`, *optional*, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to `1`): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `224`): + Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. + + Examples: + + Returns: + [`~pipelines.cogview3.pipeline_cogview3plus.CogView3PipelineOutput`] or `tuple`: + [`~pipelines.cogview3.pipeline_cogview3plus.CogView3PipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + self.do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare additional timestep conditions + original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype) + target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype) + crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype) + + if self.do_classifier_free_guidance: + original_size = torch.cat([original_size, original_size]) + target_size = torch.cat([target_size, target_size]) + crops_coords_top_left = torch.cat([crops_coords_top_left, crops_coords_top_left]) + + original_size = original_size.to(device).repeat(batch_size * num_images_per_prompt, 1) + target_size = target_size.to(device).repeat(batch_size * num_images_per_prompt, 1) + crops_coords_top_left = crops_coords_top_left.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return CogView3PipelineOutput(images=image) diff --git a/diffusers/pipelines/cogview3/pipeline_output.py b/diffusers/pipelines/cogview3/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..6c89e117b74ccc1edb6e098bbdc84098d978230a --- /dev/null +++ b/diffusers/pipelines/cogview3/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class CogView3PipelineOutput(BaseOutput): + """ + Output class for CogView3 pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image] | np.ndarray diff --git a/diffusers/pipelines/cogview4/__init__.py b/diffusers/pipelines/cogview4/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6a365e17fee7185cbfb67e857ec2d265c35674ab --- /dev/null +++ b/diffusers/pipelines/cogview4/__init__.py @@ -0,0 +1,49 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["CogView4PlusPipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_cogview4"] = ["CogView4Pipeline"] + _import_structure["pipeline_cogview4_control"] = ["CogView4ControlPipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_cogview4 import CogView4Pipeline + from .pipeline_cogview4_control import CogView4ControlPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/cogview4/pipeline_cogview4.py b/diffusers/pipelines/cogview4/pipeline_cogview4.py new file mode 100644 index 0000000000000000000000000000000000000000..329b76d11e0d404935d39820d671ff413de489a8 --- /dev/null +++ b/diffusers/pipelines/cogview4/pipeline_cogview4.py @@ -0,0 +1,687 @@ +# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import AutoTokenizer, GlmModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...loaders import CogView4LoraLoaderMixin +from ...models import AutoencoderKL, CogView4Transformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import CogView4PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogView4Pipeline + + >>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + >>> image.save("output.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + base_shift: float = 0.25, + max_shift: float = 0.75, +) -> float: + m = (image_seq_len / base_seq_len) ** 0.5 + mu = m * max_shift + base_shift + return mu + + +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + + if timesteps is not None and sigmas is not None: + if not accepts_timesteps and not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep or sigma schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif timesteps is not None and sigmas is None: + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif timesteps is None and sigmas is not None: + if not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin): + r""" + Pipeline for text-to-image generation using CogView4. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`GLMModel`]): + Frozen text-encoder. CogView4 uses [glm-4-9b-hf](https://huggingface.co/THUDM/glm-4-9b-hf). + tokenizer (`PreTrainedTokenizer`): + Tokenizer of class + [PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer). + transformer ([`CogView4Transformer2DModel`]): + A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: GlmModel, + vae: AutoencoderKL, + transformer: CogView4Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def _get_glm_embeds( + self, + prompt: str | list[str] = None, + max_sequence_length: int = 1024, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer( + prompt, + padding="longest", # not use max length + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + current_length = text_input_ids.shape[1] + pad_length = (16 - (current_length % 16)) % 16 + if pad_length > 0: + pad_ids = torch.full( + (text_input_ids.shape[0], pad_length), + fill_value=self.tokenizer.pad_token_id, + dtype=text_input_ids.dtype, + device=text_input_ids.device, + ) + text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=True).hidden_states[-2] + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 1024, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of images that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + max_sequence_length (`int`, defaults to `1024`): + Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_glm_embeds(prompt, max_sequence_length, device, dtype) + + seq_len = prompt_embeds.size(1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_glm_embeds(negative_prompt, max_sequence_length, device, dtype) + + seq_len = negative_prompt_embeds.size(1) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + if latents is not None: + return latents.to(device) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape[0] != negative_prompt_embeds.shape[0]: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same batch size when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} and `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds.shape[-1] != negative_prompt_embeds.shape[-1]: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same dimension when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} and `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + guidance_scale: float = 5.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + original_size: tuple[int, int] | None = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 1024, + ) -> CogView4PipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. If not provided, it is set to 1024. + width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. If not provided it is set to 1024. + num_inference_steps (`int`, *optional*, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to `1`): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `224`): + Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. + + Examples: + + Returns: + [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] or `tuple`: + [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = (height, width) + + # Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + self.do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + # Prepare latents + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # Prepare additional timestep conditions + original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=device) + target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device) + crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device) + + original_size = original_size.repeat(batch_size * num_images_per_prompt, 1) + target_size = target_size.repeat(batch_size * num_images_per_prompt, 1) + crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1) + + # Prepare timesteps + image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // ( + self.transformer.config.patch_size**2 + ) + timesteps = ( + np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps) + if timesteps is None + else np.array(timesteps) + ) + timesteps = timesteps.astype(np.int64).astype(np.float32) + sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("base_shift", 0.25), + self.scheduler.config.get("max_shift", 0.75), + ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas, mu=mu + ) + self._num_timesteps = len(timesteps) + + # Denoising loop + transformer_dtype = self.transformer.dtype + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents.to(transformer_dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + + with self.transformer.cache_context("cond"): + noise_pred_cond = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=negative_prompt_embeds, + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = noise_pred_cond + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + image = self.vae.decode(latents, return_dict=False, generator=generator)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return CogView4PipelineOutput(images=image) diff --git a/diffusers/pipelines/cogview4/pipeline_cogview4_control.py b/diffusers/pipelines/cogview4/pipeline_cogview4_control.py new file mode 100644 index 0000000000000000000000000000000000000000..6282bf4cd7a405fd075daf5c1913f342f2c6753f --- /dev/null +++ b/diffusers/pipelines/cogview4/pipeline_cogview4_control.py @@ -0,0 +1,734 @@ +# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import AutoTokenizer, GlmModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models import AutoencoderKL, CogView4Transformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import CogView4PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogView4ControlPipeline + + >>> pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B-Control", torch_dtype=torch.bfloat16) + >>> control_image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ... ) + >>> prompt = "A bird in space" + >>> image = pipe(prompt, control_image=control_image, height=1024, width=1024, guidance_scale=3.5).images[0] + >>> image.save("cogview4-control.png") + ``` +""" + + +# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + base_shift: float = 0.25, + max_shift: float = 0.75, +) -> float: + m = (image_seq_len / base_seq_len) ** 0.5 + mu = m * max_shift + base_shift + return mu + + +# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + + if timesteps is not None and sigmas is not None: + if not accepts_timesteps and not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep or sigma schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif timesteps is not None and sigmas is None: + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif timesteps is None and sigmas is not None: + if not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class CogView4ControlPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using CogView4. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`GLMModel`]): + Frozen text-encoder. CogView4 uses [glm-4-9b-hf](https://huggingface.co/THUDM/glm-4-9b-hf). + tokenizer (`PreTrainedTokenizer`): + Tokenizer of class + [PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer). + transformer ([`CogView4Transformer2DModel`]): + A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: GlmModel, + vae: AutoencoderKL, + transformer: CogView4Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline._get_glm_embeds + def _get_glm_embeds( + self, + prompt: str | list[str] = None, + max_sequence_length: int = 1024, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer( + prompt, + padding="longest", # not use max length + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + current_length = text_input_ids.shape[1] + pad_length = (16 - (current_length % 16)) % 16 + if pad_length > 0: + pad_ids = torch.full( + (text_input_ids.shape[0], pad_length), + fill_value=self.tokenizer.pad_token_id, + dtype=text_input_ids.dtype, + device=text_input_ids.device, + ) + text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=True).hidden_states[-2] + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + return prompt_embeds + + # Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 1024, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of images that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + max_sequence_length (`int`, defaults to `1024`): + Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_glm_embeds(prompt, max_sequence_length, device, dtype) + + seq_len = prompt_embeds.size(1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_glm_embeds(negative_prompt, max_sequence_length, device, dtype) + + seq_len = negative_prompt_embeds.size(1) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + if latents is not None: + return latents.to(device) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0, output_size=image.shape[0] * repeat_by) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + control_image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + guidance_scale: float = 5.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + original_size: tuple[int, int] | None = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 1024, + ) -> CogView4PipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. If not provided, it is set to 1024. + width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. If not provided it is set to 1024. + num_inference_steps (`int`, *optional*, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to `1`): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.pipeline_CogView4.CogView4PipelineOutput`] instead of a plain + tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `224`): + Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. + Examples: + + Returns: + [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] or `tuple`: + [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = (height, width) + + # Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + self.do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + # Prepare latents + latent_channels = self.transformer.config.in_channels // 2 + + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image.shape[-2:] + + vae_shift_factor = 0 + + control_image = self.vae.encode(control_image).latent_dist.sample() + control_image = (control_image - vae_shift_factor) * self.vae.config.scaling_factor + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # Prepare additional timestep conditions + original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=device) + target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device) + crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device) + + original_size = original_size.repeat(batch_size * num_images_per_prompt, 1) + target_size = target_size.repeat(batch_size * num_images_per_prompt, 1) + crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1) + + # Prepare timesteps + image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // ( + self.transformer.config.patch_size**2 + ) + + timesteps = ( + np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps) + if timesteps is None + else np.array(timesteps) + ) + timesteps = timesteps.astype(np.int64).astype(np.float32) + sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("base_shift", 0.25), + self.scheduler.config.get("max_shift", 0.75), + ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas, mu=mu + ) + self._num_timesteps = len(timesteps) + # Denoising loop + transformer_dtype = self.transformer.dtype + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents, control_image], dim=1).to(transformer_dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + + noise_pred_cond = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=negative_prompt_embeds, + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = noise_pred_cond + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + image = self.vae.decode(latents, return_dict=False, generator=generator)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return CogView4PipelineOutput(images=image) diff --git a/diffusers/pipelines/cogview4/pipeline_output.py b/diffusers/pipelines/cogview4/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..997444c6c009a27bf944f84dbe6025b4345f87e1 --- /dev/null +++ b/diffusers/pipelines/cogview4/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class CogView4PipelineOutput(BaseOutput): + """ + Output class for CogView3 pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image] | np.ndarray diff --git a/diffusers/pipelines/consisid/__init__.py b/diffusers/pipelines/consisid/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7b9ba330fbd129d104180a2c0e5d05672b48fa47 --- /dev/null +++ b/diffusers/pipelines/consisid/__init__.py @@ -0,0 +1,49 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_opencv_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available() and is_opencv_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_and_opencv_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_opencv_objects)) +else: + _import_structure["pipeline_consisid"] = ["ConsisIDPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_consisid import ConsisIDPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/consisid/consisid_utils.py b/diffusers/pipelines/consisid/consisid_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c1646e15efbc48c36803ac71b4b6ce198075b35b --- /dev/null +++ b/diffusers/pipelines/consisid/consisid_utils.py @@ -0,0 +1,357 @@ +import importlib.util +import os + +import cv2 +import numpy as np +import torch +from PIL import Image, ImageOps +from torchvision.transforms import InterpolationMode +from torchvision.transforms.functional import normalize, resize + +from ...utils import get_logger, load_image + + +logger = get_logger(__name__) + +_insightface_available = importlib.util.find_spec("insightface") is not None +_consisid_eva_clip_available = importlib.util.find_spec("consisid_eva_clip") is not None +_facexlib_available = importlib.util.find_spec("facexlib") is not None + +if _insightface_available: + import insightface + from insightface.app import FaceAnalysis +else: + raise ImportError("insightface is not available. Please install it using 'pip install insightface'.") + +if _consisid_eva_clip_available: + from consisid_eva_clip import create_model_and_transforms + from consisid_eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +else: + raise ImportError("consisid_eva_clip is not available. Please install it using 'pip install consisid_eva_clip'.") + +if _facexlib_available: + from facexlib.parsing import init_parsing_model + from facexlib.utils.face_restoration_helper import FaceRestoreHelper +else: + raise ImportError("facexlib is not available. Please install it using 'pip install facexlib'.") + + +def resize_numpy_image_long(image, resize_long_edge=768): + """ + Resize the input image to a specified long edge while maintaining aspect ratio. + + Args: + image (numpy.ndarray): Input image (H x W x C or H x W). + resize_long_edge (int): The target size for the long edge of the image. Default is 768. + + Returns: + numpy.ndarray: Resized image with the long edge matching `resize_long_edge`, while maintaining the aspect + ratio. + """ + + h, w = image.shape[:2] + if max(h, w) <= resize_long_edge: + return image + k = resize_long_edge / max(h, w) + h = int(h * k) + w = int(w * k) + image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4) + return image + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == "float64": + img = img.astype("float32") + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + return _totensor(imgs, bgr2rgb, float32) + + +def to_gray(img): + """ + Converts an RGB image to grayscale by applying the standard luminosity formula. + + Args: + img (torch.Tensor): The input image tensor with shape (batch_size, channels, height, width). + The image is expected to be in RGB format (3 channels). + + Returns: + torch.Tensor: The grayscale image tensor with shape (batch_size, 3, height, width). + The grayscale values are replicated across all three channels. + """ + x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3] + x = x.repeat(1, 3, 1, 1) + return x + + +def process_face_embeddings( + face_helper_1, + clip_vision_model, + face_helper_2, + eva_transform_mean, + eva_transform_std, + app, + device, + weight_dtype, + image, + original_id_image=None, + is_align_face=True, +): + """ + Process face embeddings from an image, extracting relevant features such as face embeddings, landmarks, and parsed + face features using a series of face detection and alignment tools. + + Args: + face_helper_1: Face helper object (first helper) for alignment and landmark detection. + clip_vision_model: Pre-trained CLIP vision model used for feature extraction. + face_helper_2: Face helper object (second helper) for embedding extraction. + eva_transform_mean: Mean values for image normalization before passing to EVA model. + eva_transform_std: Standard deviation values for image normalization before passing to EVA model. + app: Application instance used for face detection. + device: Device (CPU or GPU) where the computations will be performed. + weight_dtype: Data type of the weights for precision (e.g., `torch.float32`). + image: Input image in RGB format with pixel values in the range [0, 255]. + original_id_image: (Optional) Original image for feature extraction if `is_align_face` is False. + is_align_face: Boolean flag indicating whether face alignment should be performed. + + Returns: + tuple: + - id_cond: Concatenated tensor of Ante face embedding and CLIP vision embedding + - id_vit_hidden: Hidden state of the CLIP vision model, a list of tensors. + - return_face_features_image_2: Processed face features image after normalization and parsing. + - face_kps: Keypoints of the face detected in the image. + """ + + face_helper_1.clean_all() + image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + # get antelopev2 embedding + face_info = app.get(image_bgr) + if len(face_info) > 0: + face_info = sorted(face_info, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))[ + -1 + ] # only use the maximum face + id_ante_embedding = face_info["embedding"] # (512,) + face_kps = face_info["kps"] + else: + id_ante_embedding = None + face_kps = None + + # using facexlib to detect and align face + face_helper_1.read_image(image_bgr) + face_helper_1.get_face_landmarks_5(only_center_face=True) + if face_kps is None: + face_kps = face_helper_1.all_landmarks_5[0] + face_helper_1.align_warp_face() + if len(face_helper_1.cropped_faces) == 0: + raise RuntimeError("facexlib align face fail") + align_face = face_helper_1.cropped_faces[0] # (512, 512, 3) # RGB + + # in case insightface didn't detect face + if id_ante_embedding is None: + logger.warning("Failed to detect face using insightface. Extracting embedding with align face") + id_ante_embedding = face_helper_2.get_feat(align_face) + + id_ante_embedding = torch.from_numpy(id_ante_embedding).to(device, weight_dtype) # torch.Size([512]) + if id_ante_embedding.ndim == 1: + id_ante_embedding = id_ante_embedding.unsqueeze(0) # torch.Size([1, 512]) + + # parsing + if is_align_face: + input = img2tensor(align_face, bgr2rgb=True).unsqueeze(0) / 255.0 # torch.Size([1, 3, 512, 512]) + input = input.to(device) + parsing_out = face_helper_1.face_parse(normalize(input, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0] + parsing_out = parsing_out.argmax(dim=1, keepdim=True) # torch.Size([1, 1, 512, 512]) + bg_label = [0, 16, 18, 7, 8, 9, 14, 15] + bg = sum(parsing_out == i for i in bg_label).bool() + white_image = torch.ones_like(input) # torch.Size([1, 3, 512, 512]) + # only keep the face features + return_face_features_image = torch.where(bg, white_image, to_gray(input)) # torch.Size([1, 3, 512, 512]) + return_face_features_image_2 = torch.where(bg, white_image, input) # torch.Size([1, 3, 512, 512]) + else: + original_image_bgr = cv2.cvtColor(original_id_image, cv2.COLOR_RGB2BGR) + input = img2tensor(original_image_bgr, bgr2rgb=True).unsqueeze(0) / 255.0 # torch.Size([1, 3, 512, 512]) + input = input.to(device) + return_face_features_image = return_face_features_image_2 = input + + # transform img before sending to eva-clip-vit + face_features_image = resize( + return_face_features_image, clip_vision_model.image_size, InterpolationMode.BICUBIC + ) # torch.Size([1, 3, 336, 336]) + face_features_image = normalize(face_features_image, eva_transform_mean, eva_transform_std) + id_cond_vit, id_vit_hidden = clip_vision_model( + face_features_image.to(weight_dtype), return_all_features=False, return_hidden=True, shuffle=False + ) # torch.Size([1, 768]), list(torch.Size([1, 577, 1024])) + id_cond_vit_norm = torch.norm(id_cond_vit, 2, 1, True) + id_cond_vit = torch.div(id_cond_vit, id_cond_vit_norm) + + id_cond = torch.cat( + [id_ante_embedding, id_cond_vit], dim=-1 + ) # torch.Size([1, 512]), torch.Size([1, 768]) -> torch.Size([1, 1280]) + + return ( + id_cond, + id_vit_hidden, + return_face_features_image_2, + face_kps, + ) # torch.Size([1, 1280]), list(torch.Size([1, 577, 1024])) + + +def process_face_embeddings_infer( + face_helper_1, + clip_vision_model, + face_helper_2, + eva_transform_mean, + eva_transform_std, + app, + device, + weight_dtype, + img_file_path, + is_align_face=True, +): + """ + Process face embeddings from an input image for inference, including alignment, feature extraction, and embedding + concatenation. + + Args: + face_helper_1: Face helper object (first helper) for alignment and landmark detection. + clip_vision_model: Pre-trained CLIP vision model used for feature extraction. + face_helper_2: Face helper object (second helper) for embedding extraction. + eva_transform_mean: Mean values for image normalization before passing to EVA model. + eva_transform_std: Standard deviation values for image normalization before passing to EVA model. + app: Application instance used for face detection. + device: Device (CPU or GPU) where the computations will be performed. + weight_dtype: Data type of the weights for precision (e.g., `torch.float32`). + img_file_path: Path to the input image file (string) or a numpy array representing an image. + is_align_face: Boolean flag indicating whether face alignment should be performed (default: True). + + Returns: + tuple: + - id_cond: Concatenated tensor of Ante face embedding and CLIP vision embedding. + - id_vit_hidden: Hidden state of the CLIP vision model, a list of tensors. + - image: Processed face image after feature extraction and alignment. + - face_kps: Keypoints of the face detected in the image. + """ + + # Load and preprocess the input image + if isinstance(img_file_path, str): + image = np.array(load_image(image=img_file_path).convert("RGB")) + else: + image = np.array(ImageOps.exif_transpose(Image.fromarray(img_file_path)).convert("RGB")) + + # Resize image to ensure the longer side is 1024 pixels + image = resize_numpy_image_long(image, 1024) + original_id_image = image + + # Process the image to extract face embeddings and related features + id_cond, id_vit_hidden, align_crop_face_image, face_kps = process_face_embeddings( + face_helper_1, + clip_vision_model, + face_helper_2, + eva_transform_mean, + eva_transform_std, + app, + device, + weight_dtype, + image, + original_id_image, + is_align_face, + ) + + # Convert the aligned cropped face image (torch tensor) to a numpy array + tensor = align_crop_face_image.cpu().detach() + tensor = tensor.squeeze() + tensor = tensor.permute(1, 2, 0) + tensor = tensor.numpy() * 255 + tensor = tensor.astype(np.uint8) + image = ImageOps.exif_transpose(Image.fromarray(tensor)) + + return id_cond, id_vit_hidden, image, face_kps + + +def prepare_face_models(model_path, device, dtype): + """ + Prepare all face models for the facial recognition task. + + Parameters: + - model_path: Path to the directory containing model files. + - device: The device (e.g., 'cuda', 'xpu', 'cpu') where models will be loaded. + - dtype: Data type (e.g., torch.float32) for model inference. + + Returns: + - face_helper_1: First face restoration helper. + - face_helper_2: Second face restoration helper. + - face_clip_model: CLIP model for face extraction. + - eva_transform_mean: Mean value for image normalization. + - eva_transform_std: Standard deviation value for image normalization. + - face_main_model: Main face analysis model. + """ + # get helper model + face_helper_1 = FaceRestoreHelper( + upscale_factor=1, + face_size=512, + crop_ratio=(1, 1), + det_model="retinaface_resnet50", + save_ext="png", + device=device, + model_rootpath=os.path.join(model_path, "face_encoder"), + ) + face_helper_1.face_parse = None + face_helper_1.face_parse = init_parsing_model( + model_name="bisenet", device=device, model_rootpath=os.path.join(model_path, "face_encoder") + ) + face_helper_2 = insightface.model_zoo.get_model( + f"{model_path}/face_encoder/models/antelopev2/glintr100.onnx", providers=["CUDAExecutionProvider"] + ) + face_helper_2.prepare(ctx_id=0) + + # get local facial extractor part 1 + model, _, _ = create_model_and_transforms( + "EVA02-CLIP-L-14-336", + os.path.join(model_path, "face_encoder", "EVA02_CLIP_L_336_psz14_s6B.pt"), + force_custom_clip=True, + ) + face_clip_model = model.visual + eva_transform_mean = getattr(face_clip_model, "image_mean", OPENAI_DATASET_MEAN) + eva_transform_std = getattr(face_clip_model, "image_std", OPENAI_DATASET_STD) + if not isinstance(eva_transform_mean, (list, tuple)): + eva_transform_mean = (eva_transform_mean,) * 3 + if not isinstance(eva_transform_std, (list, tuple)): + eva_transform_std = (eva_transform_std,) * 3 + eva_transform_mean = eva_transform_mean + eva_transform_std = eva_transform_std + + # get local facial extractor part 2 + face_main_model = FaceAnalysis( + name="antelopev2", root=os.path.join(model_path, "face_encoder"), providers=["CUDAExecutionProvider"] + ) + face_main_model.prepare(ctx_id=0, det_size=(640, 640)) + + # move face models to device + face_helper_1.face_det.eval() + face_helper_1.face_parse.eval() + face_clip_model.eval() + face_helper_1.face_det.to(device) + face_helper_1.face_parse.to(device) + face_clip_model.to(device, dtype=dtype) + + return face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std diff --git a/diffusers/pipelines/consisid/pipeline_consisid.py b/diffusers/pipelines/consisid/pipeline_consisid.py new file mode 100644 index 0000000000000000000000000000000000000000..20b779bf5aaa50a028315d8001d5fa430a69bd86 --- /dev/null +++ b/diffusers/pipelines/consisid/pipeline_consisid.py @@ -0,0 +1,972 @@ +# Copyright 2025 ConsisID Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable + +import numpy as np +import PIL +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import CogVideoXLoraLoaderMixin +from ...models import AutoencoderKLCogVideoX, ConsisIDTransformer3DModel +from ...models.embeddings import get_3d_rotary_pos_embed +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import CogVideoXDPMScheduler +from ...utils import is_opencv_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import ConsisIDPipelineOutput + + +if is_opencv_available(): + import cv2 + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import ConsisIDPipeline + >>> from diffusers.pipelines.consisid.consisid_utils import prepare_face_models, process_face_embeddings_infer + >>> from diffusers.utils import export_to_video + >>> from huggingface_hub import snapshot_download + + >>> snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir="BestWishYsh/ConsisID-preview") + >>> ( + ... face_helper_1, + ... face_helper_2, + ... face_clip_model, + ... face_main_model, + ... eva_transform_mean, + ... eva_transform_std, + ... ) = prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16) + >>> pipe = ConsisIDPipeline.from_pretrained("BestWishYsh/ConsisID-preview", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> # ConsisID works well with long and well-described prompts. Make sure the face in the image is clearly visible (e.g., preferably half-body or full-body). + >>> prompt = "The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel." + >>> image = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_input.png?download=true" + + >>> id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer( + ... face_helper_1, + ... face_clip_model, + ... face_helper_2, + ... eva_transform_mean, + ... eva_transform_std, + ... face_main_model, + ... "cuda", + ... torch.bfloat16, + ... image, + ... is_align_face=True, + ... ) + + >>> video = pipe( + ... image=image, + ... prompt=prompt, + ... num_inference_steps=50, + ... guidance_scale=6.0, + ... use_dynamic_cfg=False, + ... id_vit_hidden=id_vit_hidden, + ... id_cond=id_cond, + ... kps_cond=face_kps, + ... generator=torch.Generator("cuda").manual_seed(42), + ... ) + >>> export_to_video(video.frames[0], "output.mp4", fps=8) + ``` +""" + + +def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]): + """ + This function draws keypoints and the limbs connecting them on an image. + + Parameters: + - image_pil (PIL.Image): Input image as a PIL object. + - kps (list of tuples): A list of keypoints where each keypoint is a tuple of (x, y) coordinates. + - color_list (list of tuples, optional): list of colors (in RGB format) for each keypoint. Default is a set of five + colors. + + Returns: + - PIL.Image: Image with the keypoints and limbs drawn. + """ + + stickwidth = 4 + limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]]) + kps = np.array(kps) + + w, h = image_pil.size + out_img = np.zeros([h, w, 3]) + + for i in range(len(limbSeq)): + index = limbSeq[i] + color = color_list[index[0]] + + x = kps[index][:, 0] + y = kps[index][:, 1] + length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])) + polygon = cv2.ellipse2Poly( + (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1 + ) + out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color) + out_img = (out_img * 0.6).astype(np.uint8) + + for idx_kp, kp in enumerate(kps): + color = color_list[idx_kp] + x, y = kp + out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1) + + out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8)) + return out_img_pil + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + """ + This function calculates the resize and crop region for an image to fit a target width and height while preserving + the aspect ratio. + + Parameters: + - src (tuple): A tuple containing the source image's height (h) and width (w). + - tgt_width (int): The target width to resize the image. + - tgt_height (int): The target height to resize the image. + + Returns: + - tuple: Two tuples representing the crop region: + 1. The top-left coordinates of the crop region. + 2. The bottom-right coordinates of the crop region. + """ + + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class ConsisIDPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): + r""" + Pipeline for image-to-video generation using ConsisID. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. ConsisID uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`ConsisIDTransformer3DModel`]): + A text conditioned `ConsisIDTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: ConsisIDTransformer3DModel, + scheduler: CogVideoXDPMScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + self.vae_scaling_factor_image = ( + self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, + image: torch.Tensor, + batch_size: int = 1, + num_channels_latents: int = 16, + num_frames: int = 13, + height: int = 60, + width: int = 90, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + kps_cond: torch.Tensor | None = None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_frames, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + image = image.unsqueeze(2) # [B, C, F, H, W] + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + if kps_cond is not None: + kps_cond = kps_cond.unsqueeze(2) + kps_cond_latents = [ + retrieve_latents(self.vae.encode(kps_cond[i].unsqueeze(0)), generator[i]) + for i in range(batch_size) + ] + else: + image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image] + if kps_cond is not None: + kps_cond = kps_cond.unsqueeze(2) + kps_cond_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in kps_cond] + + image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + image_latents = self.vae_scaling_factor_image * image_latents + + if kps_cond is not None: + kps_cond_latents = torch.cat(kps_cond_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + kps_cond_latents = self.vae_scaling_factor_image * kps_cond_latents + + padding_shape = ( + batch_size, + num_frames - 2, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + else: + padding_shape = ( + batch_size, + num_frames - 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype) + if kps_cond is not None: + image_latents = torch.cat([image_latents, kps_cond_latents, latent_padding], dim=1) + else: + image_latents = torch.cat([image_latents, latent_padding], dim=1) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents, image_latents + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae_scaling_factor_image * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, timesteps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + image, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + latents=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `list[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_width = self.transformer.config.sample_width // self.transformer.config.patch_size + base_size_height = self.transformer.config.sample_height // self.transformer.config.patch_size + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + device=device, + ) + + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + guidance_scale: float = 6.0, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 226, + id_vit_hidden: torch.Tensor | None = None, + id_cond: torch.Tensor | None = None, + kps_cond: torch.Tensor | None = None, + ) -> ConsisIDPipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The width in pixels of the generated image. This is set to 720 by default for the best results. + num_frames (`int`, defaults to `49`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because ConsisID is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 6): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + use_dynamic_cfg (`bool`, *optional*, defaults to `False`): + If True, dynamically adjusts the guidance scale during inference. This allows the model to use a + progressive guidance scale, improving the balance between text-guided generation and image quality over + the course of the inference steps. Typically, early inference steps use a higher guidance scale for + more faithful image generation, while later steps reduce it for more diverse and natural results. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + id_vit_hidden (`torch.Tensor | None`, *optional*): + The tensor representing the hidden features extracted from the face model, which are used to condition + the local facial extractor. This is crucial for the model to obtain high-frequency information of the + face. If not provided, the local facial extractor will not run normally. + id_cond (`torch.Tensor | None`, *optional*): + The tensor representing the hidden features extracted from the clip model, which are used to condition + the local facial extractor. This is crucial for the model to edit facial features If not provided, the + local facial extractor will not run normally. + kps_cond (`torch.Tensor | None`, *optional*): + A tensor that determines whether the global facial extractor use keypoint information for conditioning. + If provided, this tensor controls whether facial keypoints such as eyes, nose, and mouth landmarks are + used during the generation process. This helps ensure the model retains more facial low-frequency + information. + + Examples: + + Returns: + [`~pipelines.consisid.pipeline_output.ConsisIDPipelineOutput`] or `tuple`: + [`~pipelines.consisid.pipeline_output.ConsisIDPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + num_frames = num_frames or self.transformer.config.sample_frames + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + image=image, + prompt=prompt, + height=height, + width=width, + negative_prompt=negative_prompt, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents + is_kps = getattr(self.transformer.config, "is_kps", False) + kps_cond = kps_cond if is_kps else None + if kps_cond is not None: + kps_cond = draw_kps(image, kps_cond) + kps_cond = self.video_processor.preprocess(kps_cond, height=height, width=width).to( + device, dtype=prompt_embeds.dtype + ) + + image = self.video_processor.preprocess(image, height=height, width=width).to( + device, dtype=prompt_embeds.dtype + ) + + latent_channels = self.transformer.config.in_channels // 2 + latents, image_latents = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + kps_cond, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + timesteps_cpu = timesteps.cpu() + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents + latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + id_vit_hidden=id_vit_hidden, + id_cond=id_cond, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + ( + 1 + - math.cos( + math.pi + * ((num_inference_steps - timesteps_cpu[i].item()) / num_inference_steps) ** 5.0 + ) + ) + / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return ConsisIDPipelineOutput(frames=video) diff --git a/diffusers/pipelines/consisid/pipeline_output.py b/diffusers/pipelines/consisid/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..83a5be8d230ba998c8ccd6f450a9d679592a9894 --- /dev/null +++ b/diffusers/pipelines/consisid/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class ConsisIDPipelineOutput(BaseOutput): + r""" + Output class for ConsisID pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/diffusers/pipelines/controlnet_xs/__init__.py b/diffusers/pipelines/controlnet_xs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..978278b184f985a452f9d518a1d0eb4f271c74fd --- /dev/null +++ b/diffusers/pipelines/controlnet_xs/__init__.py @@ -0,0 +1,68 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_controlnet_xs"] = ["StableDiffusionControlNetXSPipeline"] + _import_structure["pipeline_controlnet_xs_sd_xl"] = ["StableDiffusionXLControlNetXSPipeline"] +try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_flax_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects)) +else: + pass # _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline + from .pipeline_controlnet_xs_sd_xl import StableDiffusionXLControlNetXSPipeline + + try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_flax_and_transformers_objects import * # noqa F403 + else: + pass # from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline + + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py new file mode 100644 index 0000000000000000000000000000000000000000..9c81eb57e6c53356a7605a0499131b6b07026ca6 --- /dev/null +++ b/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -0,0 +1,927 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionControlNetXSPipeline, ControlNetXSAdapter + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" + >>> negative_prompt = "low quality, bad quality, sketches" + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" + ... ) + + >>> # initialize the models and pipeline + >>> controlnet_conditioning_scale = 0.5 + + >>> controlnet = ControlNetXSAdapter.from_pretrained( + ... "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16 + ... ) + >>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> # get canny image + >>> image = np.array(image) + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + >>> # generate image + >>> image = pipe( + ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image + ... ).images[0] + ``` +""" + + +class StableDiffusionControlNetXSPipeline( + DeprecatedPipelineMixin, + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet-XS guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetControlNetXSModel to denoise the encoded image latents. + controlnet ([`ControlNetXSAdapter`]): + A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + _last_supported_version = "0.33.1" + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel | UNetControlNetXSModel, + controlnet: ControlNetXSAdapter, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if isinstance(unet, UNet2DConditionModel): + unet = UNetControlNetXSModel.from_unet(unet, controlnet) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Check `image` and `controlnet_conditioning_scale` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.unet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.unet, UNetControlNetXSModel) + or is_compiled + and isinstance(self.unet._orig_mod, UNetControlNetXSModel) + ): + self.check_image(image, prompt, prompt_embeds) + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + else: + assert False + + start, end = control_guidance_start, control_guidance_end + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale + def guidance_scale(self): + return self._guidance_scale + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip + def clip_skip(self): + return self._clip_skip + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + control_guidance_start: float = 0.0, + control_guidance_end: float = 1.0, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare image + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=unet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + height, width = image.shape[-2:] + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + is_controlnet_compiled = is_compiled_module(self.unet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if torch.cuda.is_available() and is_controlnet_compiled and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + apply_control = ( + i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end + ) + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=True, + apply_control=apply_control, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + empty_device_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..e7a5862b5b7a47849f7d77364f7544d9759cfbf6 --- /dev/null +++ b/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -0,0 +1,1116 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, +) + +from diffusers.utils.import_utils import is_invisible_watermark_available + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + + +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionXLControlNetXSPipeline, ControlNetXSAdapter, AutoencoderKL + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" + >>> negative_prompt = "low quality, bad quality, sketches" + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" + ... ) + + >>> # initialize the models and pipeline + >>> controlnet_conditioning_scale = 0.5 + >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) + >>> controlnet = ControlNetXSAdapter.from_pretrained( + ... "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16 + ... ) + >>> pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> # get canny image + >>> image = np.array(image) + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # generate image + >>> image = pipe( + ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image + ... ).images[0] + ``` +""" + + +class StableDiffusionXLControlNetXSPipeline( + DeprecatedPipelineMixin, + DiffusionPipeline, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet-XS guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]): + Second frozen text-encoder + ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + tokenizer_2 ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetControlNetXSModel to denoise the encoded image latents. + controlnet ([`ControlNetXSAdapter`]): + A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings should always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to + watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no + watermarker is used. + """ + + _last_supported_version = "0.33.1" + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel | UNetControlNetXSModel, + controlnet: ControlNetXSAdapter, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: bool | None = None, + feature_extractor: CLIPImageProcessor = None, + ): + super().__init__() + + if isinstance(unet, UNet2DConditionModel): + unet = UNetControlNetXSModel.from_unet(unet, controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + image, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Check `image` and ``controlnet_conditioning_scale`` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.unet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.unet, UNetControlNetXSModel) + or is_compiled + and isinstance(self.unet._orig_mod, UNetControlNetXSModel) + ): + self.check_image(image, prompt, prompt_embeds) + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + else: + assert False + + start, end = control_guidance_start, control_guidance_end + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.base_add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale + def guidance_scale(self): + return self._guidance_scale + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip + def clip_skip(self): + return self._clip_skip + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + control_guidance_start: float = 0.0, + control_guidance_end: float = 1.0, + original_size: tuple[int, int] = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` + and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, pooled text embeddings are generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt + weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. + control_guidance_start (`float`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] is + returned, otherwise a `tuple` is returned containing the output images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + + # 4. Prepare image + if isinstance(unet, UNetControlNetXSModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=unet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + height, width = image.shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Prepare added time ids & embeddings + if isinstance(image, list): + original_size = original_size or image[0].shape[-2:] + else: + original_size = original_size or image.shape[-2:] + target_size = target_size or (height, width) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + is_controlnet_compiled = is_compiled_module(self.unet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if is_controlnet_compiled and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # predict the noise residual + apply_control = ( + i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end + ) + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=True, + apply_control=apply_control, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # manually for max memory savings + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/diffusers/pipelines/cosmos/__init__.py b/diffusers/pipelines/cosmos/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5fc66cdf84b662e08d06949a9d14f00eb302627b --- /dev/null +++ b/diffusers/pipelines/cosmos/__init__.py @@ -0,0 +1,64 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_cosmos2_5_predict"] = [ + "Cosmos2_5_PredictBasePipeline", + ] + _import_structure["pipeline_cosmos2_5_transfer"] = [ + "Cosmos2_5_TransferPipeline", + ] + _import_structure["pipeline_cosmos2_text2image"] = ["Cosmos2TextToImagePipeline"] + _import_structure["pipeline_cosmos2_video2world"] = ["Cosmos2VideoToWorldPipeline"] + _import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"] + _import_structure["pipeline_cosmos_video2world"] = ["CosmosVideoToWorldPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_cosmos2_5_predict import ( + Cosmos2_5_PredictBasePipeline, + ) + from .pipeline_cosmos2_5_transfer import Cosmos2_5_TransferPipeline + from .pipeline_cosmos2_text2image import Cosmos2TextToImagePipeline + from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline + from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline + from .pipeline_cosmos_video2world import CosmosVideoToWorldPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py b/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py new file mode 100644 index 0000000000000000000000000000000000000000..cdea71a5ab93af1c7ec642d9b1dd862ec79a6255 --- /dev/null +++ b/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py @@ -0,0 +1,881 @@ +# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable + +import numpy as np +import torch +import torchvision +import torchvision.transforms +import torchvision.transforms.functional +from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLWan, CosmosTransformer3DModel +from ...schedulers import UniPCMultistepScheduler +from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import CosmosPipelineOutput + + +if is_cosmos_guardrail_available(): + from cosmos_guardrail import CosmosSafetyChecker +else: + + class CosmosSafetyChecker: + def __init__(self, *args, **kwargs): + raise ImportError( + "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`." + ) + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +DEFAULT_NEGATIVE_PROMPT = ( + "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, " + "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " + "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " + "Overall, the video is of poor quality." +) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import Cosmos2_5_PredictBasePipeline + >>> from diffusers.utils import export_to_video, load_image, load_video + + >>> model_id = "nvidia/Cosmos-Predict2.5-2B" + >>> pipe = Cosmos2_5_PredictBasePipeline.from_pretrained( + ... model_id, revision="diffusers/base/post-trained", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> # Common negative prompt reused across modes. + >>> negative_prompt = ( + ... "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + ... "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + ... "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky " + ... "movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " + ... "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " + ... "Overall, the video is of poor quality." + ... ) + + >>> # Text2World: generate a 93-frame world video from text only. + >>> prompt = ( + ... "As the red light shifts to green, the red bus at the intersection begins to move forward, its headlights " + ... "cutting through the falling snow. The snowy tire tracks deepen as the vehicle inches ahead, casting fresh " + ... "lines onto the slushy road. Around it, streetlights glow warmer, illuminating the drifting flakes and wet " + ... "reflections on the asphalt. Other cars behind start to edge forward, their beams joining the scene. " + ... "The stillness of the urban street transitions into motion as the quiet snowfall is punctuated by the slow " + ... "advance of traffic through the frosty city corridor." + ... ) + >>> video = pipe( + ... image=None, + ... video=None, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_frames=93, + ... generator=torch.Generator().manual_seed(1), + ... ).frames[0] + >>> export_to_video(video, "text2world.mp4", fps=16) + + >>> # Image2World: condition on a single image and generate a 93-frame world video. + >>> prompt = ( + ... "A high-definition video captures the precision of robotic welding in an industrial setting. " + ... "The first frame showcases a robotic arm, equipped with a welding torch, positioned over a large metal structure. " + ... "The welding process is in full swing, with bright sparks and intense light illuminating the scene, creating a vivid " + ... "display of blue and white hues. A significant amount of smoke billows around the welding area, partially obscuring " + ... "the view but emphasizing the heat and activity. The background reveals parts of the workshop environment, including a " + ... "ventilation system and various pieces of machinery, indicating a busy and functional industrial workspace. As the video " + ... "progresses, the robotic arm maintains its steady position, continuing the welding process and moving to its left. " + ... "The welding torch consistently emits sparks and light, and the smoke continues to rise, diffusing slightly as it moves upward. " + ... "The metal surface beneath the torch shows ongoing signs of heating and melting. The scene retains its industrial ambiance, with " + ... "the welding sparks and smoke dominating the visual field, underscoring the ongoing nature of the welding operation." + ... ) + >>> image = load_image( + ... "https://media.githubusercontent.com/media/nvidia-cosmos/cosmos-predict2.5/refs/heads/main/assets/base/robot_welding.jpg" + ... ) + >>> video = pipe( + ... image=image, + ... video=None, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_frames=93, + ... generator=torch.Generator().manual_seed(1), + ... ).frames[0] + >>> export_to_video(video, "image2world.mp4", fps=16) + + >>> # Video2World: condition on an input clip and predict a 93-frame world video. + >>> prompt = ( + ... "The video opens with an aerial view of a large-scale sand mining construction operation, showcasing extensive piles " + ... "of brown sand meticulously arranged in parallel rows. A central water channel, fed by a water pipe, flows through the " + ... "middle of these sand heaps, creating ripples and movement as it cascades down. The surrounding area features dense green " + ... "vegetation on the left, contrasting with the sandy terrain, while a body of water is visible in the background on the right. " + ... "As the video progresses, a piece of heavy machinery, likely a bulldozer, enters the frame from the right, moving slowly along " + ... "the edge of the sand piles. This machinery's presence indicates ongoing construction work in the operation. The final frame " + ... "captures the same scene, with the water continuing its flow and the bulldozer still in motion, maintaining the dynamic yet " + ... "steady pace of the construction activity." + ... ) + >>> input_video = load_video( + ... "https://github.com/nvidia-cosmos/cosmos-predict2.5/raw/refs/heads/main/assets/base/sand_mining.mp4" + ... ) + >>> video = pipe( + ... image=None, + ... video=input_video, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_frames=93, + ... generator=torch.Generator().manual_seed(1), + ... ).frames[0] + >>> export_to_video(video, "video2world.mp4", fps=16) + + >>> # To produce an image instead of a world (video) clip, set num_frames=1 and + >>> # save the first frame: pipe(..., num_frames=1).frames[0][0]. + ``` +""" + + +class Cosmos2_5_PredictBasePipeline(DiffusionPipeline): + r""" + Pipeline for [Cosmos Predict2.5](https://github.com/nvidia-cosmos/cosmos-predict2.5) base model. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`Qwen2_5_VLForConditionalGeneration`]): + Frozen text-encoder. Cosmos Predict2.5 uses the [Qwen2.5 + VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) encoder. + tokenizer (`AutoTokenizer`): + Tokenizer associated with the Qwen2.5 VL encoder. + transformer ([`CosmosTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + # We mark safety_checker as optional here to get around some test failures, but it is not really optional + _optional_components = ["safety_checker"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: AutoTokenizer, + transformer: CosmosTransformer3DModel, + vae: AutoencoderKLWan, + scheduler: UniPCMultistepScheduler, + safety_checker: CosmosSafetyChecker = None, + ): + super().__init__() + + if safety_checker is None: + safety_checker = CosmosSafetyChecker() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + ) + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).float() + if getattr(self.vae.config, "latents_mean", None) is not None + else None + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).float() + if getattr(self.vae.config, "latents_std", None) is not None + else None + ) + self.latents_mean = latents_mean + self.latents_std = latents_std + + if self.latents_mean is None or self.latents_std is None: + raise ValueError("VAE configuration must define both `latents_mean` and `latents_std`.") + + def _get_prompt_embeds( + self, + prompt: str | list[str] = None, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + + input_ids_batch = [] + + for sample_idx in range(len(prompt)): + conversations = [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "You are a helpful assistant who will provide prompts to an image generator.", + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt[sample_idx], + } + ], + }, + ] + input_ids = self.tokenizer.apply_chat_template( + conversations, + tokenize=True, + add_generation_prompt=False, + add_vision_id=False, + max_length=max_sequence_length, + truncation=True, + padding="max_length", + ) + input_ids = ( + input_ids["input_ids"] if not isinstance(input_ids, list) and "input_ids" in input_ids else input_ids + ) + input_ids = torch.LongTensor(input_ids) + input_ids_batch.append(input_ids) + + input_ids_batch = torch.stack(input_ids_batch, dim=0) + + outputs = self.text_encoder( + input_ids_batch.to(device), + output_hidden_states=True, + ) + hidden_states = outputs.hidden_states + + normalized_hidden_states = [] + for layer_idx in range(1, len(hidden_states)): + normalized_state = (hidden_states[layer_idx] - hidden_states[layer_idx].mean(dim=-1, keepdim=True)) / ( + hidden_states[layer_idx].std(dim=-1, keepdim=True) + 1e-8 + ) + normalized_hidden_states.append(normalized_state) + + prompt_embeds = torch.cat(normalized_hidden_states, dim=-1) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds + + # Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_prompt_embeds( + prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_prompt_embeds( + prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + # Modified from diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2VideoToWorldPipeline.prepare_latents and + # diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2TextToImagePipeline.prepare_latents + def prepare_latents( + self, + video: torch.Tensor | None, + batch_size: int, + num_channels_latents: int = 16, + height: int = 704, + width: int = 1280, + num_frames_in: int = 93, + num_frames_out: int = 93, + do_classifier_free_guidance: bool = True, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + B = batch_size + C = num_channels_latents + T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1 + H = height // self.vae_scale_factor_spatial + W = width // self.vae_scale_factor_spatial + shape = (B, C, T, H, W) + + if num_frames_in == 0: + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + cond_mask = torch.zeros((B, 1, T, H, W), dtype=latents.dtype, device=latents.device) + cond_indicator = torch.zeros((B, 1, T, 1, 1), dtype=latents.dtype, device=latents.device) + + cond_latents = torch.zeros_like(latents) + + return ( + latents, + cond_latents, + cond_mask, + cond_indicator, + ) + else: + if video is None: + raise ValueError("`video` must be provided when `num_frames_in` is greater than 0.") + needs_preprocessing = not (isinstance(video, torch.Tensor) and video.ndim == 5 and video.shape[1] == 3) + if needs_preprocessing: + video = self.video_processor.preprocess_video(video, height, width) + video = video.to(device=device, dtype=self.vae.dtype) + if isinstance(generator, list): + cond_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i]) + for i in range(batch_size) + ] + else: + cond_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + + cond_latents = torch.cat(cond_latents, dim=0).to(dtype) + + latents_mean = self.latents_mean.to(device=device, dtype=dtype) + latents_std = self.latents_std.to(device=device, dtype=dtype) + cond_latents = (cond_latents - latents_mean) / latents_std + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + padding_shape = (B, 1, T, H, W) + ones_padding = latents.new_ones(padding_shape) + zeros_padding = latents.new_zeros(padding_shape) + + num_cond_latent_frames = (num_frames_in - 1) // self.vae_scale_factor_temporal + 1 + cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1) + cond_indicator[:, :, 0:num_cond_latent_frames] = 1.0 + cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding + + return ( + latents, + cond_latents, + cond_mask, + cond_indicator, + ) + + # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput | None = None, + video: list[PipelineImageInput] | None = None, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + height: int = 704, + width: int = 1280, + num_frames: int = 93, + num_inference_steps: int = 36, + guidance_scale: float = 7.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int, None], PipelineCallback | MultiPipelineCallbacks] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + conditional_frame_timestep: float = 0.1, + num_latent_conditional_frames: int = 2, + ): + r""" + The call function to the pipeline for generation. Supports three modes: + + - **Text2World**: `image=None`, `video=None`, `prompt` provided. Generates a world clip. + - **Image2World**: `image` provided, `video=None`, `prompt` provided. Conditions on a single frame. + - **Video2World**: `video` provided, `image=None`, `prompt` provided. Conditions on an input clip. + + Set `num_frames=93` (default) to produce a world video, or `num_frames=1` to produce a single image frame (the + above in "*2Image mode"). + + Outputs follow `output_type` (e.g., `"pil"` returns a list of `num_frames` PIL images per prompt). + + Args: + image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*): + Optional single image for Image2World conditioning. Must be `None` when `video` is provided. + video (`list[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*): + Optional input video for Video2World conditioning. Must be `None` when `image` is provided. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide generation. Required unless `prompt_embeds` is supplied. + height (`int`, defaults to `704`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `93`): + Number of output frames. Use `93` for world (video) generation; set to `1` to return a single frame. + num_inference_steps (`int`, defaults to `35`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `7.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If + the prompt is shorter than this length, it will be padded. + num_latent_conditional_frames (`int`, defaults to `2`): + Number of latent conditional frames to use for Video2World conditioning. The number of pixel frames + extracted from the input video is calculated as `4 * (num_latent_conditional_frames - 1) + 1`. Set to 1 + for Image2World-like behavior (single frame conditioning). + + Examples: + + Returns: + [`~CosmosPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + if self.safety_checker is None: + raise ValueError( + f"You have disabled the safety checker for {self.__class__}. This is in violation of the " + "[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). " + f"Please ensure that you are compliant with the license agreement." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs) + + self._guidance_scale = guidance_scale + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + if self.safety_checker is not None: + self.safety_checker.to(device) + if prompt is not None: + prompt_list = [prompt] if isinstance(prompt, str) else prompt + for p in prompt_list: + if not self.safety_checker.check_text_safety(p): + raise ValueError( + f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the " + f"prompt abides by the NVIDIA Open Model License Agreement." + ) + + # Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + vae_dtype = self.vae.dtype + transformer_dtype = self.transformer.dtype + + num_frames_in = None + if image is not None: + if batch_size != 1: + raise ValueError(f"batch_size must be 1 for image input (given {batch_size})") + + image = torchvision.transforms.functional.to_tensor(image).unsqueeze(0) + video = torch.cat([image, torch.zeros_like(image).repeat(num_frames - 1, 1, 1, 1)], dim=0) + video = video.unsqueeze(0) + num_frames_in = 1 + elif video is None: + video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8) + num_frames_in = 0 + else: + if batch_size != 1: + raise ValueError(f"batch_size must be 1 for video input (given {batch_size})") + + if num_latent_conditional_frames not in [1, 2]: + raise ValueError( + f"num_latent_conditional_frames must be 1 or 2, but got {num_latent_conditional_frames}" + ) + + frames_to_extract = 4 * (num_latent_conditional_frames - 1) + 1 + + total_input_frames = len(video) + + if total_input_frames < frames_to_extract: + raise ValueError( + f"Input video has only {total_input_frames} frames but Video2World requires at least " + f"{frames_to_extract} frames for conditioning." + ) + + num_frames_in = frames_to_extract + + assert video is not None + video = self.video_processor.preprocess_video(video, height, width) + + # For Video2World: extract last frames_to_extract frames from input, then pad + if image is None and num_frames_in > 0 and num_frames_in < video.shape[2]: + video = video[:, :, -num_frames_in:, :, :] + + num_frames_out = num_frames + + if video.shape[2] < num_frames_out: + n_pad_frames = num_frames_out - video.shape[2] + last_frame = video[:, :, -1:, :, :] # [B, C, T==1, H, W] + pad_frames = last_frame.repeat(1, 1, n_pad_frames, 1, 1) # [B, C, T, H, W] + video = torch.cat((video, pad_frames), dim=2) + + assert num_frames_in <= num_frames_out, f"expected ({num_frames_in=}) <= ({num_frames_out=})" + + video = video.to(device=device, dtype=vae_dtype) + + num_channels_latents = self.transformer.config.in_channels - 1 + latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents( + video=video, + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames_in=num_frames_in, + num_frames_out=num_frames, + do_classifier_free_guidance=self.do_classifier_free_guidance, + dtype=torch.float32, + device=device, + generator=generator, + latents=latents, + ) + cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep + cond_mask = cond_mask.to(transformer_dtype) + + padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype) + + # Denoising loop + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + gt_velocity = (latents - cond_latent) * cond_mask + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t.cpu().item() + + # NOTE: assumes sigma(t) \in [0, 1] + sigma_t = ( + torch.tensor(self.scheduler.sigmas[i].item()) + .unsqueeze(0) + .to(device=device, dtype=transformer_dtype) + ) + + in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents + in_latents = in_latents.to(transformer_dtype) + in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t + noise_pred = self.transformer( + hidden_states=in_latents, + condition_mask=cond_mask, + timestep=in_timestep, + encoder_hidden_states=prompt_embeds, + padding_mask=padding_mask, + return_dict=False, + )[0] + # NOTE: replace velocity (noise_pred) with gt_velocity for conditioning inputs only + noise_pred = gt_velocity + noise_pred * (1 - cond_mask) + + if self.do_classifier_free_guidance: + noise_pred_neg = self.transformer( + hidden_states=in_latents, + condition_mask=cond_mask, + timestep=in_timestep, + encoder_hidden_states=negative_prompt_embeds, + padding_mask=padding_mask, + return_dict=False, + )[0] + # NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only + noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask) + noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg) + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents_mean = self.latents_mean.to(latents.device, latents.dtype) + latents_std = self.latents_std.to(latents.device, latents.dtype) + latents = latents * latents_std + latents_mean + video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0] + video = self._match_num_frames(video, num_frames) + + assert self.safety_checker is not None + self.safety_checker.to(device) + video = self.video_processor.postprocess_video(video, output_type="np") + video = (video * 255).astype(np.uint8) + video_batch = [] + for vid in video: + vid = self.safety_checker.check_video_safety(vid) + video_batch.append(vid) + video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1 + video = torch.from_numpy(video).permute(0, 4, 1, 2, 3) + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CosmosPipelineOutput(frames=video) + + def _match_num_frames(self, video: torch.Tensor, target_num_frames: int) -> torch.Tensor: + if target_num_frames <= 0 or video.shape[2] == target_num_frames: + return video + + frames_per_latent = max(self.vae_scale_factor_temporal, 1) + video = torch.repeat_interleave(video, repeats=frames_per_latent, dim=2) + + current_frames = video.shape[2] + if current_frames < target_num_frames: + pad = video[:, :, -1:, :, :].repeat(1, 1, target_num_frames - current_frames, 1, 1) + video = torch.cat([video, pad], dim=2) + elif current_frames > target_num_frames: + video = video[:, :, :target_num_frames] + + return video diff --git a/diffusers/pipelines/cosmos/pipeline_cosmos2_5_transfer.py b/diffusers/pipelines/cosmos/pipeline_cosmos2_5_transfer.py new file mode 100644 index 0000000000000000000000000000000000000000..bbe38c44355e71f97c6d93d10eaa09c84eb3677c --- /dev/null +++ b/diffusers/pipelines/cosmos/pipeline_cosmos2_5_transfer.py @@ -0,0 +1,1012 @@ +# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLWan, CosmosControlNetModel, CosmosTransformer3DModel +from ...schedulers import UniPCMultistepScheduler +from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import CosmosPipelineOutput + + +if is_cosmos_guardrail_available(): + from cosmos_guardrail import CosmosSafetyChecker +else: + + class CosmosSafetyChecker: + def __init__(self, *args, **kwargs): + raise ImportError( + "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`." + ) + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _maybe_pad_or_trim_video(video: torch.Tensor, num_frames: int): + n_pad_frames = num_frames - video.shape[2] + if n_pad_frames > 0: + last_frame = video[:, :, -1:, :, :] + video = torch.cat((video, last_frame.repeat(1, 1, n_pad_frames, 1, 1)), dim=2) + elif num_frames < video.shape[2]: + video = video[:, :, :num_frames, :, :] + return video + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +DEFAULT_NEGATIVE_PROMPT = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality." + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import cv2 + >>> import numpy as np + >>> import torch + >>> from diffusers import Cosmos2_5_TransferPipeline, AutoModel + >>> from diffusers.utils import export_to_video, load_video + + >>> model_id = "nvidia/Cosmos-Transfer2.5-2B" + >>> # Load a Transfer2.5 controlnet variant (edge, depth, seg, or blur) + >>> controlnet = AutoModel.from_pretrained(model_id, revision="diffusers/controlnet/general/edge") + >>> pipe = Cosmos2_5_TransferPipeline.from_pretrained( + ... model_id, controlnet=controlnet, revision="diffusers/general", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> # Video2World with edge control: Generate video guided by edge maps extracted from input video. + >>> prompt = ( + ... "The video is a demonstration of robotic manipulation, likely in a laboratory or testing environment. It" + ... "features two robotic arms interacting with a piece of blue fabric. The setting is a room with a beige" + ... "couch in the background, providing a neutral backdrop for the robotic activity. The robotic arms are" + ... "positioned on either side of the fabric, which is placed on a yellow cushion. The left robotic arm is" + ... "white with a black gripper, while the right arm is black with a more complex, articulated gripper. At the" + ... "beginning, the fabric is laid out on the cushion. The left robotic arm approaches the fabric, its gripper" + ... "opening and closing as it positions itself. The right arm remains stationary initially, poised to assist." + ... "As the video progresses, the left arm grips the fabric, lifting it slightly off the cushion. The right arm" + ... "then moves in, its gripper adjusting to grasp the opposite side of the fabric. Both arms work in" + ... "coordination, lifting and holding the fabric between them. The fabric is manipulated with precision," + ... "showcasing the dexterity and control of the robotic arms. The camera remains static throughout, focusing" + ... "on the interaction between the robotic arms and the fabric, allowing viewers to observe the detailed" + ... "movements and coordination involved in the task." + ... ) + >>> negative_prompt = ( + ... "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + ... "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + ... "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky " + ... "movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " + ... "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " + ... "Overall, the video is of poor quality." + ... ) + >>> input_video = load_video( + ... "https://github.com/nvidia-cosmos/cosmos-transfer2.5/raw/refs/heads/main/assets/robot_example/robot_input.mp4" + ... ) + >>> num_frames = 93 + + >>> # Extract edge maps from the input video using Canny edge detection + >>> edge_maps = [ + ... cv2.Canny(cv2.cvtColor(np.array(frame.convert("RGB")), cv2.COLOR_RGB2BGR), 100, 200) + ... for frame in input_video[:num_frames] + ... ] + >>> edge_maps = np.stack(edge_maps)[None] # (T, H, W) -> (1, T, H, W) + >>> controls = torch.from_numpy(edge_maps).expand(3, -1, -1, -1) # (1, T, H, W) -> (3, T, H, W) + >>> controls = [Image.fromarray(x.numpy()) for x in controls.permute(1, 2, 3, 0)] + >>> export_to_video(controls, "edge_controlled_video_edge.mp4", fps=30) + + >>> # Transfer inference with controls. + >>> video = pipe( + ... controls=controls, + ... controls_conditioning_scale=1.0, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_frames=num_frames, + ... ).frames[0] + >>> export_to_video(video, "edge_controlled_video.mp4", fps=30) + ``` +""" + + +class Cosmos2_5_TransferPipeline(DiffusionPipeline): + r""" + Pipeline for Cosmos Transfer2.5, supporting auto-regressive inference. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`Qwen2_5_VLForConditionalGeneration`]): + Frozen text-encoder. Cosmos Transfer2.5 uses the [Qwen2.5 + VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) encoder. + tokenizer (`AutoTokenizer`): + Tokenizer associated with the Qwen2.5 VL encoder. + transformer ([`CosmosTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + controlnet ([`CosmosControlNetModel`]): + ControlNet used to condition generation on control inputs. + """ + + model_cpu_offload_seq = "text_encoder->transformer->controlnet->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + # We mark safety_checker as optional here to get around some test failures, but it is not really optional + _optional_components = ["safety_checker"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: AutoTokenizer, + transformer: CosmosTransformer3DModel, + vae: AutoencoderKLWan, + scheduler: UniPCMultistepScheduler, + controlnet: CosmosControlNetModel, + safety_checker: Optional[CosmosSafetyChecker] = None, + ): + super().__init__() + + if safety_checker is None: + safety_checker = CosmosSafetyChecker() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + ) + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).float() + if getattr(self.vae.config, "latents_mean", None) is not None + else None + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).float() + if getattr(self.vae.config, "latents_std", None) is not None + else None + ) + self.latents_mean = latents_mean + self.latents_std = latents_std + + if self.latents_mean is None or self.latents_std is None: + raise ValueError("VAE configuration must define both `latents_mean` and `latents_std`.") + + def _get_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + + input_ids_batch = [] + + for sample_idx in range(len(prompt)): + conversations = [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "You are a helpful assistant who will provide prompts to an image generator.", + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt[sample_idx], + } + ], + }, + ] + input_ids = self.tokenizer.apply_chat_template( + conversations, + tokenize=True, + add_generation_prompt=False, + add_vision_id=False, + max_length=max_sequence_length, + truncation=True, + padding="max_length", + ) + input_ids = ( + input_ids["input_ids"] if not isinstance(input_ids, list) and "input_ids" in input_ids else input_ids + ) + input_ids = torch.LongTensor(input_ids) + input_ids_batch.append(input_ids) + + input_ids_batch = torch.stack(input_ids_batch, dim=0) + + outputs = self.text_encoder( + input_ids_batch.to(device), + output_hidden_states=True, + ) + hidden_states = outputs.hidden_states + + normalized_hidden_states = [] + for layer_idx in range(1, len(hidden_states)): + normalized_state = (hidden_states[layer_idx] - hidden_states[layer_idx].mean(dim=-1, keepdim=True)) / ( + hidden_states[layer_idx].std(dim=-1, keepdim=True) + 1e-8 + ) + normalized_hidden_states.append(normalized_state) + + prompt_embeds = torch.cat(normalized_hidden_states, dim=-1) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds + + # Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_prompt_embeds( + prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_prompt_embeds( + prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + # Modified from diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2VideoToWorldPipeline.prepare_latents and + # diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2TextToImagePipeline.prepare_latents + def prepare_latents( + self, + video: Optional[torch.Tensor], + batch_size: int, + num_channels_latents: int = 16, + height: int = 704, + width: int = 1280, + num_frames_in: int = 93, + num_frames_out: int = 93, + do_classifier_free_guidance: bool = True, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + num_cond_latent_frames: int = 0, + ) -> torch.Tensor: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + B = batch_size + C = num_channels_latents + T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1 + H = height // self.vae_scale_factor_spatial + W = width // self.vae_scale_factor_spatial + shape = (B, C, T, H, W) + + if latents is not None: + if latents.shape[1:] != shape[1:]: + raise ValueError(f"Unexpected `latents` shape, got {latents.shape}, expected {shape}.") + latents = latents.to(device=device, dtype=dtype) + else: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + if num_frames_in == 0: + cond_mask = torch.zeros((B, 1, T, H, W), dtype=latents.dtype, device=latents.device) + cond_indicator = torch.zeros((B, 1, T, 1, 1), dtype=latents.dtype, device=latents.device) + + cond_latents = torch.zeros_like(latents) + + return ( + latents, + cond_latents, + cond_mask, + cond_indicator, + ) + else: + if video is None: + raise ValueError("`video` must be provided when `num_frames_in` is greater than 0.") + video = video.to(device=device, dtype=self.vae.dtype) + if isinstance(generator, list): + cond_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i]) + for i in range(batch_size) + ] + else: + cond_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + + cond_latents = torch.cat(cond_latents, dim=0).to(dtype) + + latents_mean = self.latents_mean.to(device=device, dtype=dtype) + latents_std = self.latents_std.to(device=device, dtype=dtype) + cond_latents = (cond_latents - latents_mean) / latents_std + + padding_shape = (B, 1, T, H, W) + ones_padding = latents.new_ones(padding_shape) + zeros_padding = latents.new_zeros(padding_shape) + + cond_indicator = latents.new_zeros(B, 1, latents.size(2), 1, 1) + cond_indicator[:, :, 0:num_cond_latent_frames, :, :] = 1.0 + cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding + + return ( + latents, + cond_latents, + cond_mask, + cond_indicator, + ) + + # Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + num_ar_conditional_frames=None, + num_ar_latent_conditional_frames=None, + num_frames_per_chunk=None, + num_frames=None, + conditional_frame_timestep=0.1, + ): + if width <= 0 or height <= 0 or height % 16 != 0 or width % 16 != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 16 (& positive) but are {height} and {width}." + ) + + if num_frames is not None and num_frames <= 0: + raise ValueError(f"`num_frames` has to be a positive integer when provided but is {num_frames}.") + + if conditional_frame_timestep < 0 or conditional_frame_timestep > 1: + raise ValueError( + "`conditional_frame_timestep` has to be a float in the [0, 1] interval but is " + f"{conditional_frame_timestep}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if num_ar_latent_conditional_frames is not None and num_ar_conditional_frames is not None: + raise ValueError( + "Provide only one of `num_ar_conditional_frames` or `num_ar_latent_conditional_frames`, not both." + ) + if num_ar_latent_conditional_frames is None and num_ar_conditional_frames is None: + raise ValueError("Provide either `num_ar_conditional_frames` or `num_ar_latent_conditional_frames`.") + if num_ar_latent_conditional_frames is not None and num_ar_latent_conditional_frames < 0: + raise ValueError("`num_ar_latent_conditional_frames` must be >= 0.") + if num_ar_conditional_frames is not None and num_ar_conditional_frames < 0: + raise ValueError("`num_ar_conditional_frames` must be >= 0.") + + if num_ar_latent_conditional_frames is not None: + num_ar_conditional_frames = max( + 0, (num_ar_latent_conditional_frames - 1) * self.vae_scale_factor_temporal + 1 + ) + + min_chunk_len = self.vae_scale_factor_temporal + 1 + if num_frames_per_chunk < min_chunk_len: + logger.warning(f"{num_frames_per_chunk=} must be larger than {min_chunk_len=}, setting to min_chunk_len") + num_frames_per_chunk = min_chunk_len + + max_frames_by_rope = None + if getattr(self.transformer.config, "max_size", None) is not None: + max_frames_by_rope = max( + size // patch + for size, patch in zip(self.transformer.config.max_size, self.transformer.config.patch_size) + ) + if num_frames_per_chunk > max_frames_by_rope: + raise ValueError( + f"{num_frames_per_chunk=} is too large for RoPE setting ({max_frames_by_rope=}). " + "Please reduce `num_frames_per_chunk`." + ) + + if num_ar_conditional_frames >= num_frames_per_chunk: + raise ValueError( + f"{num_ar_conditional_frames=} must be smaller than {num_frames_per_chunk=} for chunked generation." + ) + + return num_frames_per_chunk + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + controls: PipelineImageInput | List[PipelineImageInput], + controls_conditioning_scale: Union[float, List[float]] = 1.0, + prompt: Union[str, List[str]] | None = None, + negative_prompt: Union[str, List[str]] = DEFAULT_NEGATIVE_PROMPT, + height: int = 704, + width: Optional[int] = None, + num_frames: Optional[int] = None, + num_frames_per_chunk: int = 93, + num_inference_steps: int = 36, + guidance_scale: float = 3.0, + num_videos_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + conditional_frame_timestep: float = 0.1, + num_ar_conditional_frames: Optional[int] = 1, + num_ar_latent_conditional_frames: Optional[int] = None, + ): + r""" + `controls` drive the conditioning through ControlNet. Controls are assumed to be pre-processed, e.g. edge maps + are pre-computed. + + Setting `num_frames` will restrict the total number of frames output, if not provided or assigned to None + (default) then the number of output frames will match the input `controls`. + + Auto-regressive inference is supported and thus a sliding window of `num_frames_per_chunk` frames are used per + denoising loop. In addition, when auto-regressive inference is performed, the previous + `num_ar_latent_conditional_frames` or `num_ar_conditional_frames` are used to condition the following denoising + inference loops. + + Args: + controls (`PipelineImageInput`, `List[PipelineImageInput]`): + Control image or video input used by the ControlNet. + controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`): + The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide generation. Required unless `prompt_embeds` is supplied. + height (`int`, defaults to `704`): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. If not provided, this will be determined based on the + aspect ratio of the input and the provided height. + num_frames (`int`, *optional*): + Number of output frames. Defaults to `None` to output the same number of frames as the input + `controls`. + num_inference_steps (`int`, defaults to `36`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `3.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs. Can be used to + tweak the same generation with different prompts. If not provided, a latents tensor is generated by + sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If + the prompt is shorter than this length, it will be padded. + num_ar_conditional_frames (`int`, *optional*, defaults to `1`): + Number of frames to condition on subsequent inference loops in auto-regressive inference, i.e. for the + second chunk and onwards. Only used if `num_ar_latent_conditional_frames` is `None`. + + This is only used when auto-regressive inference is performed, i.e. when the number of frames in + controls is > num_frames_per_chunk + num_ar_latent_conditional_frames (`int`, *optional*): + Number of latent frames to condition on subsequent inference loops in auto-regressive inference, i.e. + for the second chunk and onwards. Only used if `num_ar_conditional_frames` is `None`. + + This is only used when auto-regressive inference is performed, i.e. when the number of frames in + controls is > num_frames_per_chunk + Examples: + + Returns: + [`~CosmosPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + if self.safety_checker is None: + raise ValueError( + f"You have disabled the safety checker for {self.__class__}. This is in violation of the " + "[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). " + f"Please ensure that you are compliant with the license agreement." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + if width is None: + frame = controls[0] if isinstance(controls, list) else controls + if isinstance(frame, list): + frame = frame[0] + if isinstance(frame, (torch.Tensor, np.ndarray)): + if frame.ndim == 5: + frame = frame[0, 0] + elif frame.ndim == 4: + frame = frame[0] + + if isinstance(frame, PIL.Image.Image): + width = int((height + 16) * (frame.width / frame.height)) + else: + if frame.ndim != 3: + raise ValueError("`controls` must contain 3D frames in CHW format.") + width = int((height + 16) * (frame.shape[2] / frame.shape[1])) # NOTE: assuming C H W + + num_frames_per_chunk = self.check_inputs( + prompt, + height, + width, + prompt_embeds, + callback_on_step_end_tensor_inputs, + num_ar_conditional_frames, + num_ar_latent_conditional_frames, + num_frames_per_chunk, + num_frames, + conditional_frame_timestep, + ) + + if num_ar_latent_conditional_frames is not None: + num_cond_latent_frames = num_ar_latent_conditional_frames + num_ar_conditional_frames = max(0, (num_cond_latent_frames - 1) * self.vae_scale_factor_temporal + 1) + else: + num_cond_latent_frames = max(0, (num_ar_conditional_frames - 1) // self.vae_scale_factor_temporal + 1) + + self._guidance_scale = guidance_scale + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + if self.safety_checker is not None: + self.safety_checker.to(device) + if prompt is not None: + prompt_list = [prompt] if isinstance(prompt, str) else prompt + for p in prompt_list: + if not self.safety_checker.check_text_safety(p): + raise ValueError( + f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the " + f"prompt abides by the NVIDIA Open Model License Agreement." + ) + + # Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + vae_dtype = self.vae.dtype + transformer_dtype = self.transformer.dtype + + if getattr(self.transformer.config, "img_context_dim_in", None): + img_context = torch.zeros( + batch_size, + self.transformer.config.img_context_num_tokens, + self.transformer.config.img_context_dim_in, + device=prompt_embeds.device, + dtype=transformer_dtype, + ) + + if num_videos_per_prompt > 1: + img_context = img_context.repeat_interleave(num_videos_per_prompt, dim=0) + + encoder_hidden_states = (prompt_embeds, img_context) + neg_encoder_hidden_states = (negative_prompt_embeds, img_context) + else: + encoder_hidden_states = prompt_embeds + neg_encoder_hidden_states = negative_prompt_embeds + + control_video = self.video_processor.preprocess_video(controls, height, width) + if control_video.shape[0] != batch_size: + if control_video.shape[0] == 1: + control_video = control_video.repeat(batch_size, 1, 1, 1, 1) + else: + raise ValueError( + f"Expected controls batch size {batch_size} to match prompt batch size, but got {control_video.shape[0]}." + ) + + num_frames_out = control_video.shape[2] + if num_frames is not None: + num_frames_out = min(num_frames_out, num_frames) + + control_video = _maybe_pad_or_trim_video(control_video, num_frames_out) + + # chunk information + num_latent_frames_per_chunk = (num_frames_per_chunk - 1) // self.vae_scale_factor_temporal + 1 + chunk_stride = num_frames_per_chunk - num_ar_conditional_frames + chunk_idxs = [ + (start_idx, min(start_idx + num_frames_per_chunk, num_frames_out)) + for start_idx in range(0, num_frames_out - num_ar_conditional_frames, chunk_stride) + ] + + video_chunks = [] + latents_mean = self.latents_mean.to(dtype=vae_dtype, device=device) + latents_std = self.latents_std.to(dtype=vae_dtype, device=device) + + def decode_latents(latents): + latents = latents * latents_std + latents_mean + video = self.vae.decode(latents.to(dtype=self.vae.dtype, device=device), return_dict=False)[0] + return video + + latents_arg = latents + initial_num_cond_latent_frames = 0 + latent_chunks = [] + num_chunks = len(chunk_idxs) + total_steps = num_inference_steps * num_chunks + with self.progress_bar(total=total_steps) as progress_bar: + for chunk_idx, (start_idx, end_idx) in enumerate(chunk_idxs): + if chunk_idx == 0: + prev_output = torch.zeros((batch_size, num_frames_per_chunk, 3, height, width), dtype=vae_dtype) + prev_output = self.video_processor.preprocess_video(prev_output, height, width) + else: + prev_output = video_chunks[-1].clone() + if num_ar_conditional_frames > 0: + prev_output[:, :, :num_ar_conditional_frames] = prev_output[:, :, -num_ar_conditional_frames:] + prev_output[:, :, num_ar_conditional_frames:] = -1 # -1 == 0 in processed video space + else: + prev_output.fill_(-1) + + chunk_video = prev_output.to(device=device, dtype=vae_dtype) + chunk_video = _maybe_pad_or_trim_video(chunk_video, num_frames_per_chunk) + latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents( + video=chunk_video, + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=self.transformer.config.in_channels - 1, + height=height, + width=width, + num_frames_in=chunk_video.shape[2], + num_frames_out=num_frames_per_chunk, + do_classifier_free_guidance=self.do_classifier_free_guidance, + dtype=torch.float32, + device=device, + generator=generator, + num_cond_latent_frames=initial_num_cond_latent_frames + if chunk_idx == 0 + else num_cond_latent_frames, + latents=latents_arg, + ) + cond_mask = cond_mask.to(transformer_dtype) + cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep + padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype) + + chunk_control_video = control_video[:, :, start_idx:end_idx, ...].to( + device=device, dtype=self.vae.dtype + ) + chunk_control_video = _maybe_pad_or_trim_video(chunk_control_video, num_frames_per_chunk) + if isinstance(generator, list): + controls_latents = [ + retrieve_latents(self.vae.encode(chunk_control_video[i].unsqueeze(0)), generator=generator[i]) + for i in range(chunk_control_video.shape[0]) + ] + else: + controls_latents = [ + retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator) + for vid in chunk_control_video + ] + controls_latents = torch.cat(controls_latents, dim=0).to(transformer_dtype) + + controls_latents = (controls_latents - latents_mean) / latents_std + + # Denoising loop + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + + gt_velocity = (latents - cond_latent) * cond_mask + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t.cpu().item() + + # NOTE: assumes sigma(t) \in [0, 1] + sigma_t = ( + torch.tensor(self.scheduler.sigmas[i].item()) + .unsqueeze(0) + .to(device=device, dtype=transformer_dtype) + ) + + in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents + in_latents = in_latents.to(transformer_dtype) + in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t + control_output = self.controlnet( + controls_latents=controls_latents, + latents=in_latents, + timestep=in_timestep, + encoder_hidden_states=encoder_hidden_states, + condition_mask=cond_mask, + conditioning_scale=controls_conditioning_scale, + padding_mask=padding_mask, + return_dict=False, + ) + control_blocks = control_output[0] + + noise_pred = self.transformer( + hidden_states=in_latents, + timestep=in_timestep, + encoder_hidden_states=encoder_hidden_states, + block_controlnet_hidden_states=control_blocks, + condition_mask=cond_mask, + padding_mask=padding_mask, + return_dict=False, + )[0] + noise_pred = gt_velocity + noise_pred * (1 - cond_mask) + + if self.do_classifier_free_guidance: + control_output = self.controlnet( + controls_latents=controls_latents, + latents=in_latents, + timestep=in_timestep, + encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt + condition_mask=cond_mask, + conditioning_scale=controls_conditioning_scale, + padding_mask=padding_mask, + return_dict=False, + ) + control_blocks = control_output[0] + + noise_pred_neg = self.transformer( + hidden_states=in_latents, + timestep=in_timestep, + encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt + block_controlnet_hidden_states=control_blocks, + condition_mask=cond_mask, + padding_mask=padding_mask, + return_dict=False, + )[0] + # NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only + noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask) + noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg) + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == total_steps - 1 or ((i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + video_chunks.append(decode_latents(latents).detach().cpu()) + latent_chunks.append(latents.detach().cpu()) + + self._current_timestep = None + + if not output_type == "latent": + video_chunks = [ + chunk[:, :, num_ar_conditional_frames:, ...] if chunk_idx != 0 else chunk + for chunk_idx, chunk in enumerate(video_chunks) + ] + video = torch.cat(video_chunks, dim=2) + video = video[:, :, :num_frames_out, ...] + + assert self.safety_checker is not None + self.safety_checker.to(device) + video = self.video_processor.postprocess_video(video, output_type="np") + video = (video * 255).astype(np.uint8) + video_batch = [] + for vid in video: + vid = self.safety_checker.check_video_safety(vid) + if vid is None: + video_batch.append(np.zeros_like(video[0])) + else: + video_batch.append(vid) + video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1 + video = torch.from_numpy(video).permute(0, 4, 1, 2, 3) + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + latent_T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1 + latent_chunks = [ + chunk[:, :, num_cond_latent_frames:, ...] if chunk_idx != 0 else chunk + for chunk_idx, chunk in enumerate(latent_chunks) + ] + video = torch.cat(latent_chunks, dim=2) + video = video[:, :, :latent_T, ...] + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CosmosPipelineOutput(frames=video) diff --git a/diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py b/diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py new file mode 100644 index 0000000000000000000000000000000000000000..f24e19eea0d4bab2e404eb1135d0d83dd8fa0c53 --- /dev/null +++ b/diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py @@ -0,0 +1,679 @@ +# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...models import AutoencoderKLWan, CosmosTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import CosmosImagePipelineOutput + + +if is_cosmos_guardrail_available(): + from cosmos_guardrail import CosmosSafetyChecker +else: + + class CosmosSafetyChecker: + def __init__(self, *args, **kwargs): + raise ImportError( + "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`." + ) + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +DEFAULT_NEGATIVE_PROMPT = ( + "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, " + "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " + "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " + "Overall, the video is of poor quality." +) + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import Cosmos2TextToImagePipeline + + >>> # Available checkpoints: nvidia/Cosmos-Predict2-2B-Text2Image, nvidia/Cosmos-Predict2-14B-Text2Image + >>> model_id = "nvidia/Cosmos-Predict2-2B-Text2Image" + >>> pipe = Cosmos2TextToImagePipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess." + >>> negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality." + + >>> output = pipe( + ... prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1) + ... ).images[0] + >>> output.save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class Cosmos2TextToImagePipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using [Cosmos Predict2](https://github.com/nvidia-cosmos/cosmos-predict2). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. Cosmos uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-11b](https://huggingface.co/google-t5/t5-11b) variant. + tokenizer (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CosmosTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + # We mark safety_checker as optional here to get around some test failures, but it is not really optional + _optional_components = ["safety_checker"] + + def __init__( + self, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: CosmosTransformer3DModel, + vae: AutoencoderKLWan, + scheduler: FlowMatchEulerDiscreteScheduler, + safety_checker: CosmosSafetyChecker = None, + ): + super().__init__() + + if safety_checker is None: + safety_checker = CosmosSafetyChecker() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + ) + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + self.sigma_max = 80.0 + self.sigma_min = 0.002 + self.sigma_data = 1.0 + self.final_sigmas_type = "sigma_min" + if self.scheduler is not None: + self.scheduler.register_to_config( + sigma_max=self.sigma_max, + sigma_min=self.sigma_min, + sigma_data=self.sigma_data, + final_sigmas_type=self.final_sigmas_type, + ) + + # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + return_length=True, + return_offsets_mapping=False, + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=prompt_attention_mask + ).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + lengths = prompt_attention_mask.sum(dim=1).cpu() + for i, length in enumerate(lengths): + prompt_embeds[i, length:] = 0 + + return prompt_embeds + + # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt with num_videos_per_prompt->num_images_per_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: 16, + height: int = 768, + width: int = 1360, + num_frames: int = 1, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) * self.scheduler.config.sigma_max + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents * self.scheduler.config.sigma_max + + # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + height: int = 768, + width: int = 1360, + num_inference_steps: int = 35, + guidance_scale: float = 7.0, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, defaults to `768`): + The height in pixels of the generated image. + width (`int`, defaults to `1360`): + The width in pixels of the generated image. + num_inference_steps (`int`, defaults to `35`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `7.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`CosmosImagePipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~CosmosImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`CosmosImagePipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if self.safety_checker is None: + raise ValueError( + f"You have disabled the safety checker for {self.__class__}. This is in violation of the " + "[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). " + f"Please ensure that you are compliant with the license agreement." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + num_frames = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs) + + self._guidance_scale = guidance_scale + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + if self.safety_checker is not None: + self.safety_checker.to(device) + if prompt is not None: + prompt_list = [prompt] if isinstance(prompt, str) else prompt + for p in prompt_list: + if not self.safety_checker.check_text_safety(p): + raise ValueError( + f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the " + f"prompt abides by the NVIDIA Open Model License Agreement." + ) + self.safety_checker.to("cpu") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare timesteps + sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + sigmas = torch.linspace(0, 1, num_inference_steps, dtype=sigmas_dtype) + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, device=device, sigmas=sigmas) + if self.scheduler.config.get("final_sigmas_type", "zero") == "sigma_min": + # Replace the last sigma (which is zero) with the minimum sigma value + self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2] + + # 5. Prepare latent variables + transformer_dtype = self.transformer.dtype + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + current_sigma = self.scheduler.sigmas[i] + + current_t = current_sigma / (current_sigma + 1) + c_in = 1 - current_t + c_skip = 1 - current_t + c_out = -current_t + timestep = current_t.expand(latents.shape[0]).to(transformer_dtype) # [B, 1, T, 1, 1] + + latent_model_input = latents * c_in + latent_model_input = latent_model_input.to(transformer_dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + padding_mask=padding_mask, + return_dict=False, + )[0] + noise_pred = (c_skip * latents + c_out * noise_pred.float()).to(transformer_dtype) + + if self.do_classifier_free_guidance: + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + padding_mask=padding_mask, + return_dict=False, + )[0] + noise_pred_uncond = (c_skip * latents + c_out * noise_pred_uncond.float()).to(transformer_dtype) + noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_uncond) + + noise_pred = (latents - noise_pred) / current_sigma + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std / self.scheduler.config.sigma_data + latents_mean + video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0] + + if self.safety_checker is not None: + self.safety_checker.to(device) + video = self.video_processor.postprocess_video(video, output_type="np") + video = (video * 255).astype(np.uint8) + video_batch = [] + for vid in video: + vid = self.safety_checker.check_video_safety(vid) + video_batch.append(vid) + video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1 + video = torch.from_numpy(video).permute(0, 4, 1, 2, 3) + video = self.video_processor.postprocess_video(video, output_type=output_type) + self.safety_checker.to("cpu") + else: + video = self.video_processor.postprocess_video(video, output_type=output_type) + image = [batch[0] for batch in video] + if isinstance(video, torch.Tensor): + image = torch.stack(image) + elif isinstance(video, np.ndarray): + image = np.stack(image) + else: + image = latents[:, :, 0] + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return CosmosImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py b/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py new file mode 100644 index 0000000000000000000000000000000000000000..bdb13af06637fec9a7b3b81a874b3c18f08d93dd --- /dev/null +++ b/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py @@ -0,0 +1,798 @@ +# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLWan, CosmosTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import CosmosPipelineOutput + + +if is_cosmos_guardrail_available(): + from cosmos_guardrail import CosmosSafetyChecker +else: + + class CosmosSafetyChecker: + def __init__(self, *args, **kwargs): + raise ImportError( + "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`." + ) + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +DEFAULT_NEGATIVE_PROMPT = ( + "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, " + "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " + "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " + "Overall, the video is of poor quality." +) + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import Cosmos2VideoToWorldPipeline + >>> from diffusers.utils import export_to_video, load_image + + >>> # Available checkpoints: nvidia/Cosmos-Predict2-2B-Video2World, nvidia/Cosmos-Predict2-14B-Video2World + >>> model_id = "nvidia/Cosmos-Predict2-2B-Video2World" + >>> pipe = Cosmos2VideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess." + >>> negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality." + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yellow-scrubber.png" + ... ) + + >>> video = pipe( + ... image=image, prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1) + ... ).frames[0] + >>> export_to_video(video, "output.mp4", fps=16) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class Cosmos2VideoToWorldPipeline(DiffusionPipeline): + r""" + Pipeline for video-to-world generation using [Cosmos Predict2](https://github.com/nvidia-cosmos/cosmos-predict2). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. Cosmos uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-11b](https://huggingface.co/google-t5/t5-11b) variant. + tokenizer (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CosmosTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + # We mark safety_checker as optional here to get around some test failures, but it is not really optional + _optional_components = ["safety_checker"] + + def __init__( + self, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: CosmosTransformer3DModel, + vae: AutoencoderKLWan, + scheduler: FlowMatchEulerDiscreteScheduler, + safety_checker: CosmosSafetyChecker = None, + ): + super().__init__() + + if safety_checker is None: + safety_checker = CosmosSafetyChecker() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + ) + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + self.sigma_max = 80.0 + self.sigma_min = 0.002 + self.sigma_data = 1.0 + self.final_sigmas_type = "sigma_min" + if self.scheduler is not None: + self.scheduler.register_to_config( + sigma_max=self.sigma_max, + sigma_min=self.sigma_min, + sigma_data=self.sigma_data, + final_sigmas_type=self.final_sigmas_type, + ) + + # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + return_length=True, + return_offsets_mapping=False, + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=prompt_attention_mask + ).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + lengths = prompt_attention_mask.sum(dim=1).cpu() + for i, length in enumerate(lengths): + prompt_embeds[i, length:] = 0 + + return prompt_embeds + + # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, + video: torch.Tensor, + batch_size: int, + num_channels_latents: 16, + height: int = 704, + width: int = 1280, + num_frames: int = 93, + do_classifier_free_guidance: bool = True, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + num_cond_frames = video.size(2) + if num_cond_frames >= num_frames: + # Take the last `num_frames` frames for conditioning + num_cond_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + video = video[:, :, -num_frames:] + else: + num_cond_latent_frames = (num_cond_frames - 1) // self.vae_scale_factor_temporal + 1 + num_padding_frames = num_frames - num_cond_frames + last_frame = video[:, :, -1:] + padding = last_frame.repeat(1, 1, num_padding_frames, 1, 1) + video = torch.cat([video, padding], dim=2) + + if isinstance(generator, list): + init_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i]) + for i in range(batch_size) + ] + else: + init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype) + ) + init_latents = (init_latents - latents_mean) / latents_std * self.scheduler.config.sigma_data + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + latents = latents * self.scheduler.config.sigma_max + + padding_shape = (batch_size, 1, num_latent_frames, latent_height, latent_width) + ones_padding = latents.new_ones(padding_shape) + zeros_padding = latents.new_zeros(padding_shape) + + cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1) + cond_indicator[:, :, :num_cond_latent_frames] = 1.0 + cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding + + uncond_indicator = uncond_mask = None + if do_classifier_free_guidance: + uncond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1) + uncond_indicator[:, :, :num_cond_latent_frames] = 1.0 + uncond_mask = uncond_indicator * ones_padding + (1 - uncond_indicator) * zeros_padding + + return latents, init_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask + + # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput = None, + video: list[PipelineImageInput] = None, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + height: int = 704, + width: int = 1280, + num_frames: int = 93, + num_inference_steps: int = 35, + guidance_scale: float = 7.0, + fps: int = 16, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + sigma_conditioning: float = 0.0001, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*): + The image to be used as a conditioning input for the video generation. + video (`list[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*): + The video to be used as a conditioning input for the video generation. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, defaults to `704`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `93`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `35`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `7.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. + fps (`int`, defaults to `16`): + The frames per second of the generated video. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If + the prompt is shorter than this length, it will be padded. + sigma_conditioning (`float`, defaults to `0.0001`): + The sigma value used for scaling conditioning latents. Ideally, it should not be changed or should be + set to a small value close to zero. + + Examples: + + Returns: + [`~CosmosPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if self.safety_checker is None: + raise ValueError( + f"You have disabled the safety checker for {self.__class__}. This is in violation of the " + "[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). " + f"Please ensure that you are compliant with the license agreement." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs) + + self._guidance_scale = guidance_scale + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + if self.safety_checker is not None: + self.safety_checker.to(device) + if prompt is not None: + prompt_list = [prompt] if isinstance(prompt, str) else prompt + for p in prompt_list: + if not self.safety_checker.check_text_safety(p): + raise ValueError( + f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the " + f"prompt abides by the NVIDIA Open Model License Agreement." + ) + self.safety_checker.to("cpu") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare timesteps + sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + sigmas = torch.linspace(0, 1, num_inference_steps, dtype=sigmas_dtype) + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, device=device, sigmas=sigmas) + if self.scheduler.config.final_sigmas_type == "sigma_min": + # Replace the last sigma (which is zero) with the minimum sigma value + self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2] + + # 5. Prepare latent variables + vae_dtype = self.vae.dtype + transformer_dtype = self.transformer.dtype + + if image is not None: + video = self.video_processor.preprocess(image, height, width).unsqueeze(2) + else: + video = self.video_processor.preprocess_video(video, height, width) + video = video.to(device=device, dtype=vae_dtype) + + num_channels_latents = self.transformer.config.in_channels - 1 + latents, conditioning_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask = self.prepare_latents( + video, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + self.do_classifier_free_guidance, + torch.float32, + device, + generator, + latents, + ) + unconditioning_latents = None + + cond_mask = cond_mask.to(transformer_dtype) + if self.do_classifier_free_guidance: + uncond_mask = uncond_mask.to(transformer_dtype) + unconditioning_latents = conditioning_latents + + padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype) + sigma_conditioning = torch.tensor(sigma_conditioning, dtype=torch.float32, device=device) + t_conditioning = sigma_conditioning / (sigma_conditioning + 1) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + current_sigma = self.scheduler.sigmas[i] + + current_t = current_sigma / (current_sigma + 1) + c_in = 1 - current_t + c_skip = 1 - current_t + c_out = -current_t + timestep = current_t.view(1, 1, 1, 1, 1).expand( + latents.size(0), -1, latents.size(2), -1, -1 + ) # [B, 1, T, 1, 1] + + cond_latent = latents * c_in + cond_latent = cond_indicator * conditioning_latents + (1 - cond_indicator) * cond_latent + cond_latent = cond_latent.to(transformer_dtype) + cond_timestep = cond_indicator * t_conditioning + (1 - cond_indicator) * timestep + cond_timestep = cond_timestep.to(transformer_dtype) + + noise_pred = self.transformer( + hidden_states=cond_latent, + timestep=cond_timestep, + encoder_hidden_states=prompt_embeds, + fps=fps, + condition_mask=cond_mask, + padding_mask=padding_mask, + return_dict=False, + )[0] + noise_pred = (c_skip * latents + c_out * noise_pred.float()).to(transformer_dtype) + noise_pred = cond_indicator * conditioning_latents + (1 - cond_indicator) * noise_pred + + if self.do_classifier_free_guidance: + uncond_latent = latents * c_in + uncond_latent = uncond_indicator * unconditioning_latents + (1 - uncond_indicator) * uncond_latent + uncond_latent = uncond_latent.to(transformer_dtype) + uncond_timestep = uncond_indicator * t_conditioning + (1 - uncond_indicator) * timestep + uncond_timestep = uncond_timestep.to(transformer_dtype) + + noise_pred_uncond = self.transformer( + hidden_states=uncond_latent, + timestep=uncond_timestep, + encoder_hidden_states=negative_prompt_embeds, + fps=fps, + condition_mask=uncond_mask, + padding_mask=padding_mask, + return_dict=False, + )[0] + noise_pred_uncond = (c_skip * latents + c_out * noise_pred_uncond.float()).to(transformer_dtype) + noise_pred_uncond = ( + uncond_indicator * unconditioning_latents + (1 - uncond_indicator) * noise_pred_uncond + ) + noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_uncond) + + noise_pred = (latents - noise_pred) / current_sigma + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.scheduler.config.sigma_data + latents_mean + video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0] + + if self.safety_checker is not None: + self.safety_checker.to(device) + video = self.video_processor.postprocess_video(video, output_type="np") + video = (video * 255).astype(np.uint8) + video_batch = [] + for vid in video: + vid = self.safety_checker.check_video_safety(vid) + video_batch.append(vid) + video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1 + video = torch.from_numpy(video).permute(0, 4, 1, 2, 3) + video = self.video_processor.postprocess_video(video, output_type=output_type) + self.safety_checker.to("cpu") + else: + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CosmosPipelineOutput(frames=video) diff --git a/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py b/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py new file mode 100644 index 0000000000000000000000000000000000000000..e144d62d5933fa7a59c4ab63f330c20928e66053 --- /dev/null +++ b/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py @@ -0,0 +1,670 @@ +# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...models import AutoencoderKLCosmos, CosmosTransformer3DModel +from ...schedulers import EDMEulerScheduler +from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import CosmosPipelineOutput + + +if is_cosmos_guardrail_available(): + from cosmos_guardrail import CosmosSafetyChecker +else: + + class CosmosSafetyChecker: + def __init__(self, *args, **kwargs): + raise ImportError( + "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`." + ) + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +DEFAULT_NEGATIVE_PROMPT = ( + "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, " + "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " + "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " + "Overall, the video is of poor quality." +) + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CosmosTextToWorldPipeline + >>> from diffusers.utils import export_to_video + + >>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Text2World" + >>> pipe = CosmosTextToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A sleek, humanoid robot stands in a vast warehouse filled with neatly stacked cardboard boxes on industrial shelves. The robot's metallic body gleams under the bright, even lighting, highlighting its futuristic design and intricate joints. A glowing blue light emanates from its chest, adding a touch of advanced technology. The background is dominated by rows of boxes, suggesting a highly organized storage system. The floor is lined with wooden pallets, enhancing the industrial setting. The camera remains static, capturing the robot's poised stance amidst the orderly environment, with a shallow depth of field that keeps the focus on the robot while subtly blurring the background for a cinematic effect." + + >>> output = pipe(prompt=prompt).frames[0] + >>> export_to_video(output, "output.mp4", fps=30) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class CosmosTextToWorldPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-world generation using [Cosmos Predict1](https://github.com/nvidia-cosmos/cosmos-predict1). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. Cosmos uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-11b](https://huggingface.co/google-t5/t5-11b) variant. + tokenizer (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CosmosTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLCosmos`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + # We mark safety_checker as optional here to get around some test failures, but it is not really optional + _optional_components = ["safety_checker"] + + def __init__( + self, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: CosmosTransformer3DModel, + vae: AutoencoderKLCosmos, + scheduler: EDMEulerScheduler, + safety_checker: CosmosSafetyChecker = None, + ): + super().__init__() + + if safety_checker is None: + safety_checker = CosmosSafetyChecker() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + ) + + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 8 + ) + self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + return_length=True, + return_offsets_mapping=False, + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=prompt_attention_mask + ).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + lengths = prompt_attention_mask.sum(dim=1).cpu() + for i, length in enumerate(lengths): + prompt_embeds[i, length:] = 0 + + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: 16, + height: int = 704, + width: int = 1280, + num_frames: int = 121, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) * self.scheduler.config.sigma_max + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents * self.scheduler.config.sigma_max + + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + height: int = 704, + width: int = 1280, + num_frames: int = 121, + num_inference_steps: int = 36, + guidance_scale: float = 7.0, + fps: int = 30, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, defaults to `720`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `121`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `36`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `7.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. + fps (`int`, defaults to `30`): + The frames per second of the generated video. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~CosmosPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if self.safety_checker is None: + raise ValueError( + f"You have disabled the safety checker for {self.__class__}. This is in violation of the " + "[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). " + f"Please ensure that you are compliant with the license agreement." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs) + + self._guidance_scale = guidance_scale + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + if self.safety_checker is not None: + self.safety_checker.to(device) + if prompt is not None: + prompt_list = [prompt] if isinstance(prompt, str) else prompt + for p in prompt_list: + if not self.safety_checker.check_text_safety(p): + raise ValueError( + f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the " + f"prompt abides by the NVIDIA Open Model License Agreement." + ) + self.safety_checker.to("cpu") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device) + + # 5. Prepare latent variables + transformer_dtype = self.transformer.dtype + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + timestep = t.expand(latents.shape[0]).to(transformer_dtype) + + latent_model_input = latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = latent_model_input.to(transformer_dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + fps=fps, + padding_mask=padding_mask, + return_dict=False, + )[0] + + sample = latents + if self.do_classifier_free_guidance: + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + fps=fps, + padding_mask=padding_mask, + return_dict=False, + )[0] + noise_pred = torch.cat([noise_pred_uncond, noise_pred]) + sample = torch.cat([sample, sample]) + + # pred_original_sample (x0) + noise_pred = self.scheduler.step(noise_pred, t, sample, return_dict=False)[1] + self.scheduler._step_index -= 1 + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_cond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + # pred_sample (eps) + latents = self.scheduler.step( + noise_pred, t, latents, return_dict=False, pred_original_sample=noise_pred + )[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + if self.vae.config.latents_mean is not None: + latents_mean, latents_std = self.vae.config.latents_mean, self.vae.config.latents_std + latents_mean = ( + torch.tensor(latents_mean) + .view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)] + .to(latents) + ) + latents_std = ( + torch.tensor(latents_std) + .view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)] + .to(latents) + ) + latents = latents * latents_std / self.scheduler.config.sigma_data + latents_mean + else: + latents = latents / self.scheduler.config.sigma_data + video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0] + + if self.safety_checker is not None: + self.safety_checker.to(device) + video = self.video_processor.postprocess_video(video, output_type="np") + video = (video * 255).astype(np.uint8) + video_batch = [] + for vid in video: + vid = self.safety_checker.check_video_safety(vid) + video_batch.append(vid) + video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1 + video = torch.from_numpy(video).permute(0, 4, 1, 2, 3) + video = self.video_processor.postprocess_video(video, output_type=output_type) + self.safety_checker.to("cpu") + else: + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CosmosPipelineOutput(frames=video) diff --git a/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py b/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py new file mode 100644 index 0000000000000000000000000000000000000000..377c3c05d2849b9f12796160cc46ddf5e35db66e --- /dev/null +++ b/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py @@ -0,0 +1,832 @@ +# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLCosmos, CosmosTransformer3DModel +from ...schedulers import EDMEulerScheduler +from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import CosmosPipelineOutput + + +if is_cosmos_guardrail_available(): + from cosmos_guardrail import CosmosSafetyChecker +else: + + class CosmosSafetyChecker: + def __init__(self, *args, **kwargs): + raise ImportError( + "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`." + ) + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +DEFAULT_NEGATIVE_PROMPT = ( + "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, " + "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " + "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " + "Overall, the video is of poor quality." +) + +EXAMPLE_DOC_STRING = """ + Examples: + Image conditioning: + + ```python + >>> import torch + >>> from diffusers import CosmosVideoToWorldPipeline + >>> from diffusers.utils import export_to_video, load_image + + >>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Video2World" + >>> pipe = CosmosVideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "The video depicts a long, straight highway stretching into the distance, flanked by metal guardrails. The road is divided into multiple lanes, with a few vehicles visible in the far distance. The surrounding landscape features dry, grassy fields on one side and rolling hills on the other. The sky is mostly clear with a few scattered clouds, suggesting a bright, sunny day." + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input.jpg" + ... ) + + >>> video = pipe(image=image, prompt=prompt).frames[0] + >>> export_to_video(video, "output.mp4", fps=30) + ``` + + Video conditioning: + + ```python + >>> import torch + >>> from diffusers import CosmosVideoToWorldPipeline + >>> from diffusers.utils import export_to_video, load_video + + >>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Video2World" + >>> pipe = CosmosVideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) + >>> pipe.transformer = torch.compile(pipe.transformer) + >>> pipe.to("cuda") + + >>> prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region." + >>> video = load_video( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4" + ... )[ + ... :21 + ... ] # This example uses only the first 21 frames + + >>> video = pipe(video=video, prompt=prompt).frames[0] + >>> export_to_video(video, "output.mp4", fps=30) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class CosmosVideoToWorldPipeline(DiffusionPipeline): + r""" + Pipeline for image-to-world and video-to-world generation using [Cosmos + Predict-1](https://github.com/nvidia-cosmos/cosmos-predict1). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. Cosmos uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-11b](https://huggingface.co/google-t5/t5-11b) variant. + tokenizer (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CosmosTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLCosmos`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + # We mark safety_checker as optional here to get around some test failures, but it is not really optional + _optional_components = ["safety_checker"] + + def __init__( + self, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: CosmosTransformer3DModel, + vae: AutoencoderKLCosmos, + scheduler: EDMEulerScheduler, + safety_checker: CosmosSafetyChecker = None, + ): + super().__init__() + + if safety_checker is None: + safety_checker = CosmosSafetyChecker() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + ) + + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 8 + ) + self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + return_length=True, + return_offsets_mapping=False, + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=prompt_attention_mask + ).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + lengths = prompt_attention_mask.sum(dim=1).cpu() + for i, length in enumerate(lengths): + prompt_embeds[i, length:] = 0 + + return prompt_embeds + + # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, + video: torch.Tensor, + batch_size: int, + num_channels_latents: 16, + height: int = 704, + width: int = 1280, + num_frames: int = 121, + do_classifier_free_guidance: bool = True, + input_frames_guidance: bool = False, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + num_cond_frames = video.size(2) + if num_cond_frames >= num_frames: + # Take the last `num_frames` frames for conditioning + num_cond_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + video = video[:, :, -num_frames:] + else: + num_cond_latent_frames = (num_cond_frames - 1) // self.vae_scale_factor_temporal + 1 + num_padding_frames = num_frames - num_cond_frames + padding = video.new_zeros(video.size(0), video.size(1), num_padding_frames, video.size(3), video.size(4)) + video = torch.cat([video, padding], dim=2) + + if isinstance(generator, list): + init_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i]) + for i in range(batch_size) + ] + else: + init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + + if self.vae.config.latents_mean is not None: + latents_mean, latents_std = self.vae.config.latents_mean, self.vae.config.latents_std + latents_mean = ( + torch.tensor(latents_mean) + .view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : init_latents.size(2)] + .to(init_latents) + ) + latents_std = ( + torch.tensor(latents_std) + .view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : init_latents.size(2)] + .to(init_latents) + ) + init_latents = (init_latents - latents_mean) * self.scheduler.config.sigma_data / latents_std + else: + init_latents = init_latents * self.scheduler.config.sigma_data + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + latents = latents * self.scheduler.config.sigma_max + + padding_shape = (batch_size, 1, num_latent_frames, latent_height, latent_width) + ones_padding = latents.new_ones(padding_shape) + zeros_padding = latents.new_zeros(padding_shape) + + cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1) + cond_indicator[:, :, :num_cond_latent_frames] = 1.0 + cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding + + uncond_indicator = uncond_mask = None + if do_classifier_free_guidance: + uncond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1) + uncond_indicator[:, :, :num_cond_latent_frames] = 1.0 + uncond_mask = zeros_padding + if not input_frames_guidance: + uncond_mask = uncond_indicator * ones_padding + (1 - uncond_indicator) * zeros_padding + + return latents, init_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask + + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + image=None, + video=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if image is None and video is None: + raise ValueError("Either `image` or `video` has to be provided.") + if image is not None and video is not None: + raise ValueError("Only one of `image` or `video` has to be provided.") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput = None, + video: list[PipelineImageInput] = None, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + height: int = 704, + width: int = 1280, + num_frames: int = 121, + num_inference_steps: int = 36, + guidance_scale: float = 7.0, + input_frames_guidance: bool = False, + augment_sigma: float = 0.001, + fps: int = 30, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, defaults to `720`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `121`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `36`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `7.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. + fps (`int`, defaults to `30`): + The frames per second of the generated video. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~CosmosPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if self.safety_checker is None: + raise ValueError( + f"You have disabled the safety checker for {self.__class__}. This is in violation of the " + "[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). " + f"Please ensure that you are compliant with the license agreement." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs, image, video) + + self._guidance_scale = guidance_scale + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + if self.safety_checker is not None: + self.safety_checker.to(device) + if prompt is not None: + prompt_list = [prompt] if isinstance(prompt, str) else prompt + for p in prompt_list: + if not self.safety_checker.check_text_safety(p): + raise ValueError( + f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the " + f"prompt abides by the NVIDIA Open Model License Agreement." + ) + self.safety_checker.to("cpu") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device) + + # 5. Prepare latent variables + vae_dtype = self.vae.dtype + transformer_dtype = self.transformer.dtype + + if image is not None: + video = self.video_processor.preprocess(image, height, width).unsqueeze(2) + else: + video = self.video_processor.preprocess_video(video, height, width) + video = video.to(device=device, dtype=vae_dtype) + + num_channels_latents = self.transformer.config.in_channels - 1 + latents, conditioning_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask = self.prepare_latents( + video, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + self.do_classifier_free_guidance, + input_frames_guidance, + torch.float32, + device, + generator, + latents, + ) + cond_mask = cond_mask.to(transformer_dtype) + if self.do_classifier_free_guidance: + uncond_mask = uncond_mask.to(transformer_dtype) + + augment_sigma = torch.tensor([augment_sigma], device=device, dtype=torch.float32) + padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + timestep = t.expand(latents.shape[0]).to(transformer_dtype) + + current_sigma = self.scheduler.sigmas[i] + is_augment_sigma_greater = augment_sigma >= current_sigma + + c_in_augment = self.scheduler._get_conditioning_c_in(augment_sigma) + c_in_original = self.scheduler._get_conditioning_c_in(current_sigma) + + current_cond_indicator = cond_indicator * 0 if is_augment_sigma_greater else cond_indicator + cond_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=torch.float32) + cond_latent = conditioning_latents + cond_noise * augment_sigma[:, None, None, None, None] + cond_latent = cond_latent * c_in_augment / c_in_original + cond_latent = current_cond_indicator * cond_latent + (1 - current_cond_indicator) * latents + cond_latent = self.scheduler.scale_model_input(cond_latent, t) + cond_latent = cond_latent.to(transformer_dtype) + + noise_pred = self.transformer( + hidden_states=cond_latent, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + fps=fps, + condition_mask=cond_mask, + padding_mask=padding_mask, + return_dict=False, + )[0] + + sample = latents + if self.do_classifier_free_guidance: + current_uncond_indicator = uncond_indicator * 0 if is_augment_sigma_greater else uncond_indicator + uncond_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=torch.float32) + uncond_latent = conditioning_latents + uncond_noise * augment_sigma[:, None, None, None, None] + uncond_latent = uncond_latent * c_in_augment / c_in_original + uncond_latent = current_uncond_indicator * uncond_latent + (1 - current_uncond_indicator) * latents + uncond_latent = self.scheduler.scale_model_input(uncond_latent, t) + uncond_latent = uncond_latent.to(transformer_dtype) + + noise_pred_uncond = self.transformer( + hidden_states=uncond_latent, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + fps=fps, + condition_mask=uncond_mask, + padding_mask=padding_mask, + return_dict=False, + )[0] + noise_pred = torch.cat([noise_pred_uncond, noise_pred]) + sample = torch.cat([sample, sample]) + + # pred_original_sample (x0) + noise_pred = self.scheduler.step(noise_pred, t, sample, return_dict=False)[1] + self.scheduler._step_index -= 1 + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0) + noise_pred_uncond = ( + current_uncond_indicator * conditioning_latents + + (1 - current_uncond_indicator) * noise_pred_uncond + ) + noise_pred_cond = ( + current_cond_indicator * conditioning_latents + (1 - current_cond_indicator) * noise_pred_cond + ) + noise_pred = noise_pred_cond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = ( + current_cond_indicator * conditioning_latents + (1 - current_cond_indicator) * noise_pred + ) + + # pred_sample (eps) + latents = self.scheduler.step( + noise_pred, t, latents, return_dict=False, pred_original_sample=noise_pred + )[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + if self.vae.config.latents_mean is not None: + latents_mean, latents_std = self.vae.config.latents_mean, self.vae.config.latents_std + latents_mean = ( + torch.tensor(latents_mean) + .view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)] + .to(latents) + ) + latents_std = ( + torch.tensor(latents_std) + .view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)] + .to(latents) + ) + latents = latents * latents_std / self.scheduler.config.sigma_data + latents_mean + else: + latents = latents / self.scheduler.config.sigma_data + video = self.vae.decode(latents.to(vae_dtype), return_dict=False)[0] + + if self.safety_checker is not None: + self.safety_checker.to(device) + video = self.video_processor.postprocess_video(video, output_type="np") + video = (video * 255).astype(np.uint8) + video_batch = [] + for vid in video: + vid = self.safety_checker.check_video_safety(vid) + video_batch.append(vid) + video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1 + video = torch.from_numpy(video).permute(0, 4, 1, 2, 3) + video = self.video_processor.postprocess_video(video, output_type=output_type) + self.safety_checker.to("cpu") + else: + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CosmosPipelineOutput(frames=video) diff --git a/diffusers/pipelines/cosmos/pipeline_output.py b/diffusers/pipelines/cosmos/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..1ded292f8dfbdd7e938ebdcea1bc9fd03b545b84 --- /dev/null +++ b/diffusers/pipelines/cosmos/pipeline_output.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image +import torch + +from diffusers.utils import BaseOutput, get_logger + + +logger = get_logger(__name__) + + +@dataclass +class CosmosPipelineOutput(BaseOutput): + r""" + Output class for Cosmos any-to-world/video pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor + + +@dataclass +class CosmosImagePipelineOutput(BaseOutput): + """ + Output class for Cosmos any-to-image pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image] | np.ndarray diff --git a/diffusers/pipelines/ddpm/__init__.py b/diffusers/pipelines/ddpm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eb41dd1dcf642c791f3d7b0d985efcaf3e4a2c22 --- /dev/null +++ b/diffusers/pipelines/ddpm/__init__.py @@ -0,0 +1,22 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + _LazyModule, +) + + +_import_structure = {"pipeline_ddpm": ["DDPMPipeline"]} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_ddpm import DDPMPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/diffusers/pipelines/ddpm/pipeline_ddpm.py b/diffusers/pipelines/ddpm/pipeline_ddpm.py new file mode 100644 index 0000000000000000000000000000000000000000..6d4796cbea1f3fe00f4a3417bcff96ea09e104a3 --- /dev/null +++ b/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -0,0 +1,139 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +from ...models import UNet2DModel +from ...schedulers import DDPMScheduler +from ...utils import is_torch_xla_available +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +class DDPMPipeline(DiffusionPipeline): + r""" + Pipeline for image generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of + [`DDPMScheduler`], or [`DDIMScheduler`]. + """ + + model_cpu_offload_seq = "unet" + + def __init__(self, unet: UNet2DModel, scheduler: DDPMScheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + num_inference_steps: int = 1000, + output_type: str | None = "pil", + return_dict: bool = True, + ) -> ImagePipelineOutput | tuple: + r""" + The call function to the pipeline for generation. + + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + num_inference_steps (`int`, *optional*, defaults to 1000): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Example: + + ```py + >>> from diffusers import DDPMPipeline + + >>> # load model and scheduler + >>> pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256") + + >>> # run pipeline in inference (sample random noise and denoise) + >>> image = pipe().images[0] + + >>> # save image + >>> image.save("ddpm_generated_image.png") + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + # Sample gaussian noise to begin loop + if isinstance(self.unet.config.sample_size, int): + image_shape = ( + batch_size, + self.unet.config.in_channels, + self.unet.config.sample_size, + self.unet.config.sample_size, + ) + else: + image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size) + + if self.device.type == "mps": + # randn does not work reproducibly on mps + image = randn_tensor(image_shape, generator=generator, dtype=self.unet.dtype) + image = image.to(self.device) + else: + image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype) + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.progress_bar(self.scheduler.timesteps): + # 1. predict noise model_output + model_output = self.unet(image, t).sample + + # 2. compute previous image: x_t -> x_t-1 + image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample + + if XLA_AVAILABLE: + xm.mark_step() + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/deprecated/__init__.py b/diffusers/pipelines/deprecated/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9936323170adbceac2c5c25e3881ea731d8602e1 --- /dev/null +++ b/diffusers/pipelines/deprecated/__init__.py @@ -0,0 +1,153 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_librosa_available, + is_note_seq_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_pt_objects + + _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) +else: + _import_structure["latent_diffusion_uncond"] = ["LDMPipeline"] + _import_structure["pndm"] = ["PNDMPipeline"] + _import_structure["repaint"] = ["RePaintPipeline"] + _import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"] + _import_structure["stochastic_karras_ve"] = ["KarrasVePipeline"] + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["alt_diffusion"] = [ + "AltDiffusionImg2ImgPipeline", + "AltDiffusionPipeline", + "AltDiffusionPipelineOutput", + ] + _import_structure["versatile_diffusion"] = [ + "VersatileDiffusionDualGuidedPipeline", + "VersatileDiffusionImageVariationPipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionTextToImagePipeline", + ] + _import_structure["vq_diffusion"] = ["VQDiffusionPipeline"] + _import_structure["stable_diffusion_variants"] = [ + "CycleDiffusionPipeline", + "StableDiffusionInpaintPipelineLegacy", + "StableDiffusionPix2PixZeroPipeline", + "StableDiffusionParadigmsPipeline", + "StableDiffusionModelEditingPipeline", + ] + +try: + if not (is_torch_available() and is_librosa_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_librosa_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_librosa_objects)) + +else: + _import_structure["audio_diffusion"] = ["AudioDiffusionPipeline", "Mel"] + +try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects)) + +else: + _import_structure["spectrogram_diffusion"] = ["MidiProcessor", "SpectrogramDiffusionPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_pt_objects import * + + else: + from .latent_diffusion_uncond import LDMPipeline + from .pndm import PNDMPipeline + from .repaint import RePaintPipeline + from .score_sde_ve import ScoreSdeVePipeline + from .stochastic_karras_ve import KarrasVePipeline + + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + + else: + from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline, AltDiffusionPipelineOutput + from .audio_diffusion import AudioDiffusionPipeline, Mel + from .spectrogram_diffusion import SpectrogramDiffusionPipeline + from .stable_diffusion_variants import ( + CycleDiffusionPipeline, + StableDiffusionInpaintPipelineLegacy, + StableDiffusionModelEditingPipeline, + StableDiffusionParadigmsPipeline, + StableDiffusionPix2PixZeroPipeline, + ) + from .stochastic_karras_ve import KarrasVePipeline + from .versatile_diffusion import ( + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, + ) + from .vq_diffusion import VQDiffusionPipeline + + try: + if not (is_torch_available() and is_librosa_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_librosa_objects import * + else: + from .audio_diffusion import AudioDiffusionPipeline, Mel + + try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 + else: + from .spectrogram_diffusion import ( + MidiProcessor, + SpectrogramDiffusionPipeline, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/deprecated/alt_diffusion/__init__.py b/diffusers/pipelines/deprecated/alt_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71fa15b3feff08dc4008d1fa02ba61ad1300efed --- /dev/null +++ b/diffusers/pipelines/deprecated/alt_diffusion/__init__.py @@ -0,0 +1,53 @@ +from typing import TYPE_CHECKING + +from ....utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ....utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modeling_roberta_series"] = ["RobertaSeriesModelWithTransformation"] + _import_structure["pipeline_alt_diffusion"] = ["AltDiffusionPipeline"] + _import_structure["pipeline_alt_diffusion_img2img"] = ["AltDiffusionImg2ImgPipeline"] + + _import_structure["pipeline_output"] = ["AltDiffusionPipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ....utils.dummy_torch_and_transformers_objects import * + + else: + from .modeling_roberta_series import RobertaSeriesModelWithTransformation + from .pipeline_alt_diffusion import AltDiffusionPipeline + from .pipeline_alt_diffusion_img2img import AltDiffusionImg2ImgPipeline + from .pipeline_output import AltDiffusionPipelineOutput + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py b/diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py new file mode 100644 index 0000000000000000000000000000000000000000..ed72e505b9c34ce373ce8a13beaea528ede2a09e --- /dev/null +++ b/diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py @@ -0,0 +1,123 @@ +from dataclasses import dataclass + +import torch +from torch import nn +from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel +from transformers.utils import ModelOutput + + +@dataclass +class TransformationModelOutput(ModelOutput): + """ + Base class for text model's outputs that also contains a pooling of the last hidden states. + + Args: + text_embeds (`torch.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + tuple of `torch.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one + for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + tuple of `torch.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + projection_state: torch.Tensor | None = None + last_hidden_state: torch.Tensor = None + hidden_states: tuple[torch.Tensor] | None = None + attentions: tuple[torch.Tensor] | None = None + + +class RobertaSeriesConfig(XLMRobertaConfig): + def __init__( + self, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + project_dim=512, + pooler_fn="cls", + learn_encoder=False, + use_attention_mask=True, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + self.project_dim = project_dim + self.pooler_fn = pooler_fn + self.learn_encoder = learn_encoder + self.use_attention_mask = use_attention_mask + + +class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler", r"logit_scale"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + base_model_prefix = "roberta" + config_class = RobertaSeriesConfig + + def __init__(self, config): + super().__init__(config) + self.roberta = XLMRobertaModel(config) + self.transformation = nn.Linear(config.hidden_size, config.project_dim) + self.has_pre_transformation = getattr(config, "has_pre_transformation", False) + if self.has_pre_transformation: + self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim) + self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_init() + + def forward( + self, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + head_mask: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + output_attentions: bool | None = None, + return_dict: bool | None = None, + output_hidden_states: bool | None = None, + ): + r""" """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.base_model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=True if self.has_pre_transformation else output_hidden_states, + return_dict=return_dict, + ) + + if self.has_pre_transformation: + sequence_output2 = outputs["hidden_states"][-2] + sequence_output2 = self.pre_LN(sequence_output2) + projection_state2 = self.transformation_pre(sequence_output2) + + return TransformationModelOutput( + projection_state=projection_state2, + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + projection_state = self.transformation(outputs.last_hidden_state) + return TransformationModelOutput( + projection_state=projection_state, + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py b/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..9fab42916e9e61176544ac868e46ac2d5f72c1fc --- /dev/null +++ b/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py @@ -0,0 +1,990 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMRobertaTokenizer + +from ....configuration_utils import FrozenDict +from ....image_processor import PipelineImageInput, VaeImageProcessor +from ....loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ....models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ....models.lora import adjust_lora_scale_text_encoder +from ....schedulers import KarrasDiffusionSchedulers +from ....utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .modeling_roberta_series import RobertaSeriesModelWithTransformation +from .pipeline_output import AltDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AltDiffusionPipeline + + >>> pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> # "dark elf princess, highly detailed, d & d, fantasy, highly detailed, digital painting, trending on artstation, concept art, sharp focus, illustration, art by artgerm and greg rutkowski and fuji choko and viktoria gavrilenko and hoang lap" + >>> prompt = "黑暗精灵公主,非常详细,幻想,非常详细,数字绘画,概念艺术,敏锐的焦点,插图" + >>> image = pipe(prompt).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AltDiffusionPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using Alt Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.RobertaSeriesModelWithTransformation`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.XLMRobertaTokenizer`]): + A `XLMRobertaTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: RobertaSeriesModelWithTransformation, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + timesteps (`torch.Tensor`): + generate embedding vectors at these timesteps + embedding_dim (`int`, *optional*, defaults to 512): + dimension of the embeddings to generate + dtype: + data type of the generated embeddings + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + guidance_rescale: float = 0.0, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + # to deal with lora scaling and other possible forward hooks + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 6.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py b/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..b6cd51c6d2030369fb4a0e23dbad58839cb3e4e4 --- /dev/null +++ b/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -0,0 +1,1045 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMRobertaTokenizer + +from ....configuration_utils import FrozenDict +from ....image_processor import PipelineImageInput, VaeImageProcessor +from ....loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ....models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ....models.lora import adjust_lora_scale_text_encoder +from ....schedulers import KarrasDiffusionSchedulers +from ....utils import ( + PIL_INTERPOLATION, + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .modeling_roberta_series import RobertaSeriesModelWithTransformation +from .pipeline_output import AltDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import requests + >>> import torch + >>> from PIL import Image + >>> from io import BytesIO + + >>> from diffusers import AltDiffusionImg2ImgPipeline + + >>> device = "cuda" + >>> model_id_or_path = "BAAI/AltDiffusion-m9" + >>> pipe = AltDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) + >>> pipe = pipe.to(device) + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + + >>> response = requests.get(url) + >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> init_image = init_image.resize((768, 512)) + + >>> # "A fantasy landscape, trending on artstation" + >>> prompt = "幻想风景, artstation" + + >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images + >>> images[0].save("幻想风景.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AltDiffusionImg2ImgPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-guided image-to-image generation using Alt Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.RobertaSeriesModelWithTransformation`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.XLMRobertaTokenizer`]): + A `XLMRobertaTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: RobertaSeriesModelWithTransformation, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + timesteps (`torch.Tensor`): + generate embedding vectors at these timesteps + embedding_dim (`int`, *optional*, defaults to 512): + dimension of the embeddings to generate + dtype: + data type of the generated embeddings + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + image: PipelineImageInput = None, + strength: float = 0.8, + num_inference_steps: int | None = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float | None = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float | None = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + Examples: + + Returns: + [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. set timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 7.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/deprecated/alt_diffusion/pipeline_output.py b/diffusers/pipelines/deprecated/alt_diffusion/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..259a8675696551b4010396aadec14a66a00a372d --- /dev/null +++ b/diffusers/pipelines/deprecated/alt_diffusion/pipeline_output.py @@ -0,0 +1,27 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ....utils import ( + BaseOutput, +) + + +@dataclass +# Copied from diffusers.pipelines.stable_diffusion.pipeline_output.StableDiffusionPipelineOutput with Stable->Alt +class AltDiffusionPipelineOutput(BaseOutput): + """ + Output class for Alt Diffusion pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + nsfw_content_detected (`list[bool]`) + list indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or + `None` if safety checking could not be performed. + """ + + images: list[PIL.Image.Image] | np.ndarray + nsfw_content_detected: list[bool] | None diff --git a/diffusers/pipelines/deprecated/audio_diffusion/__init__.py b/diffusers/pipelines/deprecated/audio_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3127951863a7db3f9dd8e42ac5ab64fa9ac3ec0c --- /dev/null +++ b/diffusers/pipelines/deprecated/audio_diffusion/__init__.py @@ -0,0 +1,23 @@ +from typing import TYPE_CHECKING + +from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule + + +_import_structure = { + "mel": ["Mel"], + "pipeline_audio_diffusion": ["AudioDiffusionPipeline"], +} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .mel import Mel + from .pipeline_audio_diffusion import AudioDiffusionPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/diffusers/pipelines/deprecated/audio_diffusion/mel.py b/diffusers/pipelines/deprecated/audio_diffusion/mel.py new file mode 100644 index 0000000000000000000000000000000000000000..0902f993a0600aa08a39a638f4a75bea5d301415 --- /dev/null +++ b/diffusers/pipelines/deprecated/audio_diffusion/mel.py @@ -0,0 +1,179 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np # noqa: E402 + +from ....configuration_utils import ConfigMixin, register_to_config +from ....schedulers.scheduling_utils import SchedulerMixin + + +try: + import librosa # noqa: E402 + + _librosa_can_be_imported = True + _import_error = "" +except Exception as e: + _librosa_can_be_imported = False + _import_error = ( + f"Cannot import librosa because {e}. Make sure to correctly install librosa to be able to install it." + ) + + +from PIL import Image # noqa: E402 + + +class Mel(ConfigMixin, SchedulerMixin): + """ + Parameters: + x_res (`int`): + x resolution of spectrogram (time). + y_res (`int`): + y resolution of spectrogram (frequency bins). + sample_rate (`int`): + Sample rate of audio. + n_fft (`int`): + Number of Fast Fourier Transforms. + hop_length (`int`): + Hop length (a higher number is recommended if `y_res` < 256). + top_db (`int`): + Loudest decibel value. + n_iter (`int`): + Number of iterations for Griffin-Lim Mel inversion. + """ + + config_name = "mel_config.json" + + @register_to_config + def __init__( + self, + x_res: int = 256, + y_res: int = 256, + sample_rate: int = 22050, + n_fft: int = 2048, + hop_length: int = 512, + top_db: int = 80, + n_iter: int = 32, + ): + self.hop_length = hop_length + self.sr = sample_rate + self.n_fft = n_fft + self.top_db = top_db + self.n_iter = n_iter + self.set_resolution(x_res, y_res) + self.audio = None + + if not _librosa_can_be_imported: + raise ValueError(_import_error) + + def set_resolution(self, x_res: int, y_res: int): + """Set resolution. + + Args: + x_res (`int`): + x resolution of spectrogram (time). + y_res (`int`): + y resolution of spectrogram (frequency bins). + """ + self.x_res = x_res + self.y_res = y_res + self.n_mels = self.y_res + self.slice_size = self.x_res * self.hop_length - 1 + + def load_audio(self, audio_file: str = None, raw_audio: np.ndarray = None): + """Load audio. + + Args: + audio_file (`str`): + An audio file that must be on disk due to [Librosa](https://librosa.org/) limitation. + raw_audio (`np.ndarray`): + The raw audio file as a NumPy array. + """ + if audio_file is not None: + self.audio, _ = librosa.load(audio_file, mono=True, sr=self.sr) + else: + self.audio = raw_audio + + # Pad with silence if necessary. + if len(self.audio) < self.x_res * self.hop_length: + self.audio = np.concatenate([self.audio, np.zeros((self.x_res * self.hop_length - len(self.audio),))]) + + def get_number_of_slices(self) -> int: + """Get number of slices in audio. + + Returns: + `int`: + Number of spectograms audio can be sliced into. + """ + return len(self.audio) // self.slice_size + + def get_audio_slice(self, slice: int = 0) -> np.ndarray: + """Get slice of audio. + + Args: + slice (`int`): + Slice number of audio (out of `get_number_of_slices()`). + + Returns: + `np.ndarray`: + The audio slice as a NumPy array. + """ + return self.audio[self.slice_size * slice : self.slice_size * (slice + 1)] + + def get_sample_rate(self) -> int: + """Get sample rate. + + Returns: + `int`: + Sample rate of audio. + """ + return self.sr + + def audio_slice_to_image(self, slice: int) -> Image.Image: + """Convert slice of audio to spectrogram. + + Args: + slice (`int`): + Slice number of audio to convert (out of `get_number_of_slices()`). + + Returns: + `PIL Image`: + A grayscale image of `x_res x y_res`. + """ + S = librosa.feature.melspectrogram( + y=self.get_audio_slice(slice), sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_mels=self.n_mels + ) + log_S = librosa.power_to_db(S, ref=np.max, top_db=self.top_db) + bytedata = (((log_S + self.top_db) * 255 / self.top_db).clip(0, 255) + 0.5).astype(np.uint8) + image = Image.fromarray(bytedata) + return image + + def image_to_audio(self, image: Image.Image) -> np.ndarray: + """Converts spectrogram to audio. + + Args: + image (`PIL Image`): + An grayscale image of `x_res x y_res`. + + Returns: + audio (`np.ndarray`): + The audio as a NumPy array. + """ + bytedata = np.frombuffer(image.tobytes(), dtype="uint8").reshape((image.height, image.width)) + log_S = bytedata.astype("float") * self.top_db / 255 - self.top_db + S = librosa.db_to_power(log_S) + audio = librosa.feature.inverse.mel_to_audio( + S, sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_iter=self.n_iter + ) + return audio diff --git a/diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py b/diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..f63fc8aacbc8b29f6cbebb87154f918260933ea8 --- /dev/null +++ b/diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py @@ -0,0 +1,325 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from math import acos, sin + +import numpy as np +import torch +from PIL import Image + +from ....models import AutoencoderKL, UNet2DConditionModel +from ....schedulers import DDIMScheduler, DDPMScheduler +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput +from .mel import Mel + + +class AudioDiffusionPipeline(DiffusionPipeline): + """ + Pipeline for audio diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + vqae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + mel ([`Mel`]): + Transform audio into a spectrogram. + scheduler ([`DDIMScheduler`] or [`DDPMScheduler`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`] or [`DDPMScheduler`]. + """ + + _optional_components = ["vqvae"] + + def __init__( + self, + vqvae: AutoencoderKL, + unet: UNet2DConditionModel, + mel: Mel, + scheduler: DDIMScheduler | DDPMScheduler, + ): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler, mel=mel, vqvae=vqvae) + + def get_default_steps(self) -> int: + """Returns default number of steps recommended for inference. + + Returns: + `int`: + The number of steps. + """ + return 50 if isinstance(self.scheduler, DDIMScheduler) else 1000 + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + audio_file: str = None, + raw_audio: np.ndarray = None, + slice: int = 0, + start_step: int = 0, + steps: int = None, + generator: torch.Generator = None, + mask_start_secs: float = 0, + mask_end_secs: float = 0, + step_generator: torch.Generator = None, + eta: float = 0, + noise: torch.Tensor = None, + encoding: torch.Tensor = None, + return_dict=True, + ) -> AudioPipelineOutput | ImagePipelineOutput | tuple[list[Image.Image], tuple[int, list[np.ndarray]]]: + """ + The call function to the pipeline for generation. + + Args: + batch_size (`int`): + Number of samples to generate. + audio_file (`str`): + An audio file that must be on disk due to [Librosa](https://librosa.org/) limitation. + raw_audio (`np.ndarray`): + The raw audio file as a NumPy array. + slice (`int`): + Slice number of audio to convert. + start_step (int): + Step to start diffusion from. + steps (`int`): + Number of denoising steps (defaults to `50` for DDIM and `1000` for DDPM). + generator (`torch.Generator`): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + mask_start_secs (`float`): + Number of seconds of audio to mask (not generate) at start. + mask_end_secs (`float`): + Number of seconds of audio to mask (not generate) at end. + step_generator (`torch.Generator`): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) used to denoise. + None + eta (`float`): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + noise (`torch.Tensor`): + A noise tensor of shape `(batch_size, 1, height, width)` or `None`. + encoding (`torch.Tensor`): + A tensor for [`UNet2DConditionModel`] of shape `(batch_size, seq_length, cross_attention_dim)`. + return_dict (`bool`): + Whether or not to return a [`AudioPipelineOutput`], [`ImagePipelineOutput`] or a plain tuple. + + Examples: + + For audio diffusion: + + ```py + import torch + from IPython.display import Audio + from diffusers import DiffusionPipeline + + device = "cuda" if torch.cuda.is_available() else "cpu" + pipe = DiffusionPipeline.from_pretrained("teticio/audio-diffusion-256").to(device) + + output = pipe() + display(output.images[0]) + display(Audio(output.audios[0], rate=mel.get_sample_rate())) + ``` + + For latent audio diffusion: + + ```py + import torch + from IPython.display import Audio + from diffusers import DiffusionPipeline + + device = "cuda" if torch.cuda.is_available() else "cpu" + pipe = DiffusionPipeline.from_pretrained("teticio/latent-audio-diffusion-256").to(device) + + output = pipe() + display(output.images[0]) + display(Audio(output.audios[0], rate=pipe.mel.get_sample_rate())) + ``` + + For other tasks like variation, inpainting, outpainting, etc: + + ```py + output = pipe( + raw_audio=output.audios[0, 0], + start_step=int(pipe.get_default_steps() / 2), + mask_start_secs=1, + mask_end_secs=1, + ) + display(output.images[0]) + display(Audio(output.audios[0], rate=pipe.mel.get_sample_rate())) + ``` + + Returns: + `list[PIL Image]`: + A list of Mel spectrograms (`float`, `list[np.ndarray]`) with the sample rate and raw audio. + """ + + steps = steps or self.get_default_steps() + self.scheduler.set_timesteps(steps) + step_generator = step_generator or generator + # For backwards compatibility + if isinstance(self.unet.config.sample_size, int): + self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size) + if noise is None: + noise = randn_tensor( + ( + batch_size, + self.unet.config.in_channels, + self.unet.config.sample_size[0], + self.unet.config.sample_size[1], + ), + generator=generator, + device=self.device, + ) + images = noise + mask = None + + if audio_file is not None or raw_audio is not None: + self.mel.load_audio(audio_file, raw_audio) + input_image = self.mel.audio_slice_to_image(slice) + input_image = np.frombuffer(input_image.tobytes(), dtype="uint8").reshape( + (input_image.height, input_image.width) + ) + input_image = (input_image / 255) * 2 - 1 + input_images = torch.tensor(input_image[np.newaxis, :, :], dtype=torch.float).to(self.device) + + if self.vqvae is not None: + input_images = self.vqvae.encode(torch.unsqueeze(input_images, 0)).latent_dist.sample( + generator=generator + )[0] + input_images = self.vqvae.config.scaling_factor * input_images + + if start_step > 0: + images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1]) + + pixels_per_second = ( + self.unet.config.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length + ) + mask_start = int(mask_start_secs * pixels_per_second) + mask_end = int(mask_end_secs * pixels_per_second) + mask = self.scheduler.add_noise(input_images, noise, torch.tensor(self.scheduler.timesteps[start_step:])) + + for step, t in enumerate(self.progress_bar(self.scheduler.timesteps[start_step:])): + if isinstance(self.unet, UNet2DConditionModel): + model_output = self.unet(images, t, encoding)["sample"] + else: + model_output = self.unet(images, t)["sample"] + + if isinstance(self.scheduler, DDIMScheduler): + images = self.scheduler.step( + model_output=model_output, + timestep=t, + sample=images, + eta=eta, + generator=step_generator, + )["prev_sample"] + else: + images = self.scheduler.step( + model_output=model_output, + timestep=t, + sample=images, + generator=step_generator, + )["prev_sample"] + + if mask is not None: + if mask_start > 0: + images[:, :, :, :mask_start] = mask[:, step, :, :mask_start] + if mask_end > 0: + images[:, :, :, -mask_end:] = mask[:, step, :, -mask_end:] + + if self.vqvae is not None: + # 0.18215 was scaling factor used in training to ensure unit variance + images = 1 / self.vqvae.config.scaling_factor * images + images = self.vqvae.decode(images)["sample"] + + images = (images / 2 + 0.5).clamp(0, 1) + images = images.cpu().permute(0, 2, 3, 1).numpy() + images = (images * 255).round().astype("uint8") + images = list( + (Image.fromarray(_[:, :, 0]) for _ in images) + if images.shape[3] == 1 + else (Image.fromarray(_, mode="RGB").convert("L") for _ in images) + ) + + audios = [self.mel.image_to_audio(_) for _ in images] + if not return_dict: + return images, (self.mel.get_sample_rate(), audios) + + return BaseOutput(**AudioPipelineOutput(np.array(audios)[:, np.newaxis, :]), **ImagePipelineOutput(images)) + + @torch.no_grad() + def encode(self, images: list[Image.Image], steps: int = 50) -> np.ndarray: + """ + Reverse the denoising step process to recover a noisy image from the generated image. + + Args: + images (`list[PIL Image]`): + list of images to encode. + steps (`int`): + Number of encoding steps to perform (defaults to `50`). + + Returns: + `np.ndarray`: + A noise tensor of shape `(batch_size, 1, height, width)`. + """ + + # Only works with DDIM as this method is deterministic + assert isinstance(self.scheduler, DDIMScheduler) + self.scheduler.set_timesteps(steps) + sample = np.array( + [np.frombuffer(image.tobytes(), dtype="uint8").reshape((1, image.height, image.width)) for image in images] + ) + sample = (sample / 255) * 2 - 1 + sample = torch.Tensor(sample).to(self.device) + + for t in self.progress_bar(torch.flip(self.scheduler.timesteps, (0,))): + prev_timestep = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps + alpha_prod_t = self.scheduler.alphas_cumprod[t] + alpha_prod_t_prev = ( + self.scheduler.alphas_cumprod[prev_timestep] + if prev_timestep >= 0 + else self.scheduler.final_alpha_cumprod + ) + beta_prod_t = 1 - alpha_prod_t + model_output = self.unet(sample, t)["sample"] + pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * model_output + sample = (sample - pred_sample_direction) * alpha_prod_t_prev ** (-0.5) + sample = sample * alpha_prod_t ** (0.5) + beta_prod_t ** (0.5) * model_output + + return sample + + @staticmethod + def slerp(x0: torch.Tensor, x1: torch.Tensor, alpha: float) -> torch.Tensor: + """Spherical Linear intERPolation. + + Args: + x0 (`torch.Tensor`): + The first tensor to interpolate between. + x1 (`torch.Tensor`): + Second tensor to interpolate between. + alpha (`float`): + Interpolation between 0 and 1 + + Returns: + `torch.Tensor`: + The interpolated tensor. + """ + + theta = acos(torch.dot(torch.flatten(x0), torch.flatten(x1)) / torch.norm(x0) / torch.norm(x1)) + return sin((1 - alpha) * theta) * x0 / sin(theta) + sin(alpha * theta) * x1 / sin(theta) diff --git a/diffusers/pipelines/deprecated/latent_diffusion_uncond/__init__.py b/diffusers/pipelines/deprecated/latent_diffusion_uncond/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..214f5bbca969f9ae0629578c72aaf339f86ded88 --- /dev/null +++ b/diffusers/pipelines/deprecated/latent_diffusion_uncond/__init__.py @@ -0,0 +1,18 @@ +from typing import TYPE_CHECKING + +from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule + + +_import_structure = {"pipeline_latent_diffusion_uncond": ["LDMPipeline"]} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_latent_diffusion_uncond import LDMPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py new file mode 100644 index 0000000000000000000000000000000000000000..70a65e2ef5be290f1443205d1239956908b54670 --- /dev/null +++ b/diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -0,0 +1,129 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import torch + +from ....models import UNet2DModel, VQModel +from ....schedulers import DDIMScheduler +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class LDMPipeline(DiffusionPipeline): + r""" + Pipeline for unconditional image generation using latent diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) model to encode and decode images to and from latent representations. + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + [`DDIMScheduler`] is used in combination with `unet` to denoise the encoded image latents. + """ + + def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler): + super().__init__() + self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + eta: float = 0.0, + num_inference_steps: int = 50, + output_type: str | None = "pil", + return_dict: bool = True, + **kwargs, + ) -> tuple | ImagePipelineOutput: + r""" + The call function to the pipeline for generation. + + Args: + batch_size (`int`, *optional*, defaults to 1): + Number of images to generate. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Example: + + ```py + >>> from diffusers import LDMPipeline + + >>> # load model and scheduler + >>> pipe = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256") + + >>> # run pipeline in inference (sample random noise and denoise) + >>> image = pipe().images[0] + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + + latents = randn_tensor( + (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), + generator=generator, + ) + latents = latents.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + self.scheduler.set_timesteps(num_inference_steps) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + + extra_kwargs = {} + if accepts_eta: + extra_kwargs["eta"] = eta + + for t in self.progress_bar(self.scheduler.timesteps): + latent_model_input = self.scheduler.scale_model_input(latents, t) + # predict the noise residual + noise_prediction = self.unet(latent_model_input, t).sample + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample + + # adjust latents with inverse of vae scale + latents = latents / self.vqvae.config.scaling_factor + # decode the image latents with the VAE + image = self.vqvae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/deprecated/pndm/__init__.py b/diffusers/pipelines/deprecated/pndm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5e3bdba74079d77576655e22b43014a0438a9c2e --- /dev/null +++ b/diffusers/pipelines/deprecated/pndm/__init__.py @@ -0,0 +1,18 @@ +from typing import TYPE_CHECKING + +from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule + + +_import_structure = {"pipeline_pndm": ["PNDMPipeline"]} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_pndm import PNDMPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/diffusers/pipelines/deprecated/pndm/pipeline_pndm.py b/diffusers/pipelines/deprecated/pndm/pipeline_pndm.py new file mode 100644 index 0000000000000000000000000000000000000000..fb116511f7275750a11d2c8cf791bcedeffaff82 --- /dev/null +++ b/diffusers/pipelines/deprecated/pndm/pipeline_pndm.py @@ -0,0 +1,119 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +from ....models import UNet2DModel +from ....schedulers import PNDMScheduler +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class PNDMPipeline(DiffusionPipeline): + r""" + Pipeline for unconditional image generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image latents. + scheduler ([`PNDMScheduler`]): + A `PNDMScheduler` to be used in combination with `unet` to denoise the encoded image. + """ + + unet: UNet2DModel + scheduler: PNDMScheduler + + def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler): + super().__init__() + + scheduler = PNDMScheduler.from_config(scheduler.config) + + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 50, + generator: torch.Generator | list[torch.Generator] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + **kwargs, + ) -> ImagePipelineOutput | tuple: + r""" + The call function to the pipeline for generation. + + Args: + batch_size (`int`, `optional`, defaults to 1): + The number of images to generate. + num_inference_steps (`int`, `optional`, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator`, `optional`): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + output_type (`str`, `optional`, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. + + Example: + + ```py + >>> from diffusers import PNDMPipeline + + >>> # load model and scheduler + >>> pndm = PNDMPipeline.from_pretrained("google/ddpm-cifar10-32") + + >>> # run pipeline in inference (sample random noise and denoise) + >>> image = pndm().images[0] + + >>> # save image + >>> image.save("pndm_generated_image.png") + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + # For more information on the sampling method you can take a look at Algorithm 2 of + # the official paper: https://huggingface.co/papers/2202.09778 + + # Sample gaussian noise to begin loop + image = randn_tensor( + (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), + generator=generator, + device=self.device, + ) + + self.scheduler.set_timesteps(num_inference_steps) + for t in self.progress_bar(self.scheduler.timesteps): + model_output = self.unet(image, t).sample + + image = self.scheduler.step(model_output, t, image).prev_sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/deprecated/repaint/__init__.py b/diffusers/pipelines/deprecated/repaint/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2c6b04af52d40e8a2bfa2aa5812b9fb8b1da06f5 --- /dev/null +++ b/diffusers/pipelines/deprecated/repaint/__init__.py @@ -0,0 +1,19 @@ +from typing import TYPE_CHECKING + +from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule + + +_import_structure = {"pipeline_repaint": ["RePaintPipeline"]} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_repaint import RePaintPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py b/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py new file mode 100644 index 0000000000000000000000000000000000000000..3231d5e130497fd5dd091291be4a164e7f76ac6b --- /dev/null +++ b/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py @@ -0,0 +1,229 @@ +# Copyright 2025 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import PIL.Image +import torch + +from ....models import UNet2DModel +from ....schedulers import RePaintScheduler +from ....utils import PIL_INTERPOLATION, deprecate, logging +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def _preprocess_image(image: list | PIL.Image.Image | torch.Tensor): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +def _preprocess_mask(mask: list | PIL.Image.Image | torch.Tensor): + if isinstance(mask, torch.Tensor): + return mask + elif isinstance(mask, PIL.Image.Image): + mask = [mask] + + if isinstance(mask[0], PIL.Image.Image): + w, h = mask[0].size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + mask = [np.array(m.convert("L").resize((w, h), resample=PIL_INTERPOLATION["nearest"]))[None, :] for m in mask] + mask = np.concatenate(mask, axis=0) + mask = mask.astype(np.float32) / 255.0 + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + elif isinstance(mask[0], torch.Tensor): + mask = torch.cat(mask, dim=0) + return mask + + +class RePaintPipeline(DiffusionPipeline): + r""" + Pipeline for image inpainting using RePaint. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image latents. + scheduler ([`RePaintScheduler`]): + A `RePaintScheduler` to be used in combination with `unet` to denoise the encoded image. + """ + + unet: UNet2DModel + scheduler: RePaintScheduler + model_cpu_offload_seq = "unet" + + def __init__(self, unet: UNet2DModel, scheduler: RePaintScheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + image: torch.Tensor | PIL.Image.Image, + mask_image: torch.Tensor | PIL.Image.Image, + num_inference_steps: int = 250, + eta: float = 0.0, + jump_length: int = 10, + jump_n_sample: int = 10, + generator: torch.Generator | list[torch.Generator] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + ) -> ImagePipelineOutput | tuple: + r""" + The call function to the pipeline for generation. + + Args: + image (`torch.Tensor` or `PIL.Image.Image`): + The original image to inpaint on. + mask_image (`torch.Tensor` or `PIL.Image.Image`): + The mask_image where 0.0 define which part of the original image to inpaint. + num_inference_steps (`int`, *optional*, defaults to 1000): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + eta (`float`): + The weight of the added noise in a diffusion step. Its value is between 0.0 and 1.0; 0.0 corresponds to + DDIM and 1.0 is the DDPM scheduler. + jump_length (`int`, *optional*, defaults to 10): + The number of steps taken forward in time before going backward in time for a single jump ("j" in + RePaint paper). Take a look at Figure 9 and 10 in the + [paper](https://huggingface.co/papers/2201.09865). + jump_n_sample (`int`, *optional*, defaults to 10): + The number of times to make a forward time jump for a given chosen time sample. Take a look at Figure 9 + and 10 in the [paper](https://huggingface.co/papers/2201.09865). + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + output_type (`str`, `optional`, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. + + Example: + + ```py + >>> from io import BytesIO + >>> import torch + >>> import PIL + >>> import requests + >>> from diffusers import RePaintPipeline, RePaintScheduler + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/celeba_hq_256.png" + >>> mask_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png" + + >>> # Load the original image and the mask as PIL images + >>> original_image = download_image(img_url).resize((256, 256)) + >>> mask_image = download_image(mask_url).resize((256, 256)) + + >>> # Load the RePaint scheduler and pipeline based on a pretrained DDPM model + >>> scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256") + >>> pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> output = pipe( + ... image=original_image, + ... mask_image=mask_image, + ... num_inference_steps=250, + ... eta=0.0, + ... jump_length=10, + ... jump_n_sample=10, + ... generator=generator, + ... ) + >>> inpainted_image = output.images[0] + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + original_image = image + + original_image = _preprocess_image(original_image) + original_image = original_image.to(device=self._execution_device, dtype=self.unet.dtype) + mask_image = _preprocess_mask(mask_image) + mask_image = mask_image.to(device=self._execution_device, dtype=self.unet.dtype) + + batch_size = original_image.shape[0] + + # sample gaussian noise to begin the loop + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + image_shape = original_image.shape + image = randn_tensor(image_shape, generator=generator, device=self._execution_device, dtype=self.unet.dtype) + + # set step values + self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self._execution_device) + self.scheduler.eta = eta + + t_last = self.scheduler.timesteps[0] + 1 + generator = generator[0] if isinstance(generator, list) else generator + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + if t < t_last: + # predict the noise residual + model_output = self.unet(image, t).sample + # compute previous image: x_t -> x_t-1 + image = self.scheduler.step(model_output, t, image, original_image, mask_image, generator).prev_sample + + else: + # compute the reverse: x_t-1 -> x_t + image = self.scheduler.undo_step(image, t_last, generator) + t_last = t + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/deprecated/score_sde_ve/__init__.py b/diffusers/pipelines/deprecated/score_sde_ve/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..87c167c3dbd26e0408a41ef197a42dc5eb7038d7 --- /dev/null +++ b/diffusers/pipelines/deprecated/score_sde_ve/__init__.py @@ -0,0 +1,19 @@ +from typing import TYPE_CHECKING + +from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule + + +_import_structure = {"pipeline_score_sde_ve": ["ScoreSdeVePipeline"]} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_score_sde_ve import ScoreSdeVePipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py b/diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py new file mode 100644 index 0000000000000000000000000000000000000000..c6abdba42d3ca23ce154f96fe4a7baa052d6203f --- /dev/null +++ b/diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py @@ -0,0 +1,107 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from ....models import UNet2DModel +from ....schedulers import ScoreSdeVeScheduler +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class ScoreSdeVePipeline(DiffusionPipeline): + r""" + Pipeline for unconditional image generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image. + scheduler ([`ScoreSdeVeScheduler`]): + A `ScoreSdeVeScheduler` to be used in combination with `unet` to denoise the encoded image. + """ + + unet: UNet2DModel + scheduler: ScoreSdeVeScheduler + + def __init__(self, unet: UNet2DModel, scheduler: ScoreSdeVeScheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 2000, + generator: torch.Generator | list[torch.Generator] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + **kwargs, + ) -> ImagePipelineOutput | tuple: + r""" + The call function to the pipeline for generation. + + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + generator (`torch.Generator`, `optional`): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + output_type (`str`, `optional`, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + img_size = self.unet.config.sample_size + shape = (batch_size, 3, img_size, img_size) + + model = self.unet + + sample = randn_tensor(shape, generator=generator) * self.scheduler.init_noise_sigma + sample = sample.to(self.device) + + self.scheduler.set_timesteps(num_inference_steps) + self.scheduler.set_sigmas(num_inference_steps) + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device) + + # correction step + for _ in range(self.scheduler.config.correct_steps): + model_output = self.unet(sample, sigma_t).sample + sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample + + # prediction step + model_output = model(sample, sigma_t).sample + output = self.scheduler.step_pred(model_output, t, sample, generator=generator) + + sample, sample_mean = output.prev_sample, output.prev_sample_mean + + sample = sample_mean.clamp(0, 1) + sample = sample.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + sample = self.numpy_to_pil(sample) + + if not return_dict: + return (sample,) + + return ImagePipelineOutput(images=sample) diff --git a/diffusers/pipelines/deprecated/spectrogram_diffusion/__init__.py b/diffusers/pipelines/deprecated/spectrogram_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..150954baa0eb8f8a7216b4891effc14a71e21b1b --- /dev/null +++ b/diffusers/pipelines/deprecated/spectrogram_diffusion/__init__.py @@ -0,0 +1,75 @@ +# flake8: noqa +from typing import TYPE_CHECKING +from ....utils import ( + DIFFUSERS_SLOW_IMPORT, + _LazyModule, + is_note_seq_available, + OptionalDependencyNotAvailable, + is_torch_available, + is_transformers_available, + get_objects_from_module, +) + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ....utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["continous_encoder"] = ["SpectrogramContEncoder"] + _import_structure["notes_encoder"] = ["SpectrogramNotesEncoder"] + _import_structure["pipeline_spectrogram_diffusion"] = [ + "SpectrogramContEncoder", + "SpectrogramDiffusionPipeline", + "T5FilmDecoder", + ] +try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ....utils import dummy_transformers_and_torch_and_note_seq_objects + + _dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects)) +else: + _import_structure["midi_utils"] = ["MidiProcessor"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ....utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_spectrogram_diffusion import SpectrogramDiffusionPipeline + from .pipeline_spectrogram_diffusion import SpectrogramContEncoder + from .pipeline_spectrogram_diffusion import SpectrogramNotesEncoder + from .pipeline_spectrogram_diffusion import T5FilmDecoder + + try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ....utils.dummy_transformers_and_torch_and_note_seq_objects import * + + else: + from .midi_utils import MidiProcessor + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py b/diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b26e84f728693b29f62cc058c43810fc4983619b --- /dev/null +++ b/diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py @@ -0,0 +1,92 @@ +# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from transformers.modeling_utils import ModuleUtilsMixin +from transformers.models.t5.modeling_t5 import ( + T5Block, + T5Config, + T5LayerNorm, +) + +from ....configuration_utils import ConfigMixin, register_to_config +from ....models import ModelMixin + + +class SpectrogramContEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): + @register_to_config + def __init__( + self, + input_dims: int, + targets_context_length: int, + d_model: int, + dropout_rate: float, + num_layers: int, + num_heads: int, + d_kv: int, + d_ff: int, + feed_forward_proj: str, + is_decoder: bool = False, + ): + super().__init__() + + self.input_proj = nn.Linear(input_dims, d_model, bias=False) + + self.position_encoding = nn.Embedding(targets_context_length, d_model) + self.position_encoding.weight.requires_grad = False + + self.dropout_pre = nn.Dropout(p=dropout_rate) + + t5config = T5Config( + d_model=d_model, + num_heads=num_heads, + d_kv=d_kv, + d_ff=d_ff, + feed_forward_proj=feed_forward_proj, + dropout_rate=dropout_rate, + is_decoder=is_decoder, + is_encoder_decoder=False, + ) + self.encoders = nn.ModuleList() + for lyr_num in range(num_layers): + lyr = T5Block(t5config) + self.encoders.append(lyr) + + self.layer_norm = T5LayerNorm(d_model) + self.dropout_post = nn.Dropout(p=dropout_rate) + + def forward(self, encoder_inputs, encoder_inputs_mask): + x = self.input_proj(encoder_inputs) + + # terminal relative positional encodings + max_positions = encoder_inputs.shape[1] + input_positions = torch.arange(max_positions, device=encoder_inputs.device) + + seq_lens = encoder_inputs_mask.sum(-1) + input_positions = torch.roll(input_positions.unsqueeze(0), tuple(seq_lens.tolist()), dims=0) + x += self.position_encoding(input_positions) + + x = self.dropout_pre(x) + + # inverted the attention mask + input_shape = encoder_inputs.size() + extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, input_shape) + + for lyr in self.encoders: + x = lyr(x, extended_attention_mask)[0] + x = self.layer_norm(x) + + return self.dropout_post(x), encoder_inputs_mask diff --git a/diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py b/diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..76b8576468d2c4502f4970d441292deda3c60965 --- /dev/null +++ b/diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py @@ -0,0 +1,667 @@ +# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +import math +import os +from typing import Any, Callable, Mapping, MutableMapping, Sequence + +import numpy as np +import torch +import torch.nn.functional as F + +from ....utils import is_note_seq_available +from .pipeline_spectrogram_diffusion import TARGET_FEATURE_LENGTH + + +if is_note_seq_available(): + import note_seq +else: + raise ImportError("Please install note-seq via `pip install note-seq`") + + +INPUT_FEATURE_LENGTH = 2048 + +SAMPLE_RATE = 16000 +HOP_SIZE = 320 +FRAME_RATE = int(SAMPLE_RATE // HOP_SIZE) + +DEFAULT_STEPS_PER_SECOND = 100 +DEFAULT_MAX_SHIFT_SECONDS = 10 +DEFAULT_NUM_VELOCITY_BINS = 1 + +SLAKH_CLASS_PROGRAMS = { + "Acoustic Piano": 0, + "Electric Piano": 4, + "Chromatic Percussion": 8, + "Organ": 16, + "Acoustic Guitar": 24, + "Clean Electric Guitar": 26, + "Distorted Electric Guitar": 29, + "Acoustic Bass": 32, + "Electric Bass": 33, + "Violin": 40, + "Viola": 41, + "Cello": 42, + "Contrabass": 43, + "Orchestral Harp": 46, + "Timpani": 47, + "String Ensemble": 48, + "Synth Strings": 50, + "Choir and Voice": 52, + "Orchestral Hit": 55, + "Trumpet": 56, + "Trombone": 57, + "Tuba": 58, + "French Horn": 60, + "Brass Section": 61, + "Soprano/Alto Sax": 64, + "Tenor Sax": 66, + "Baritone Sax": 67, + "Oboe": 68, + "English Horn": 69, + "Bassoon": 70, + "Clarinet": 71, + "Pipe": 73, + "Synth Lead": 80, + "Synth Pad": 88, +} + + +@dataclasses.dataclass +class NoteRepresentationConfig: + """Configuration note representations.""" + + onsets_only: bool + include_ties: bool + + +@dataclasses.dataclass +class NoteEventData: + pitch: int + velocity: int | None = None + program: int | None = None + is_drum: bool | None = None + instrument: int | None = None + + +@dataclasses.dataclass +class NoteEncodingState: + """Encoding state for note transcription, keeping track of active pitches.""" + + # velocity bin for active pitches and programs + active_pitches: MutableMapping[tuple[int, int], int] = dataclasses.field(default_factory=dict) + + +@dataclasses.dataclass +class EventRange: + type: str + min_value: int + max_value: int + + +@dataclasses.dataclass +class Event: + type: str + value: int + + +class Tokenizer: + def __init__(self, regular_ids: int): + # The special tokens: 0=PAD, 1=EOS, and 2=UNK + self._num_special_tokens = 3 + self._num_regular_tokens = regular_ids + + def encode(self, token_ids): + encoded = [] + for token_id in token_ids: + if not 0 <= token_id < self._num_regular_tokens: + raise ValueError( + f"token_id {token_id} does not fall within valid range of [0, {self._num_regular_tokens})" + ) + encoded.append(token_id + self._num_special_tokens) + + # Add EOS token + encoded.append(1) + + # Pad to till INPUT_FEATURE_LENGTH + encoded = encoded + [0] * (INPUT_FEATURE_LENGTH - len(encoded)) + + return encoded + + +class Codec: + """Encode and decode events. + + Useful for declaring what certain ranges of a vocabulary should be used for. This is intended to be used from + Python before encoding or after decoding with GenericTokenVocabulary. This class is more lightweight and does not + include things like EOS or UNK token handling. + + To ensure that 'shift' events are always the first block of the vocab and start at 0, that event type is required + and specified separately. + """ + + def __init__(self, max_shift_steps: int, steps_per_second: float, event_ranges: list[EventRange]): + """Define Codec. + + Args: + max_shift_steps: Maximum number of shift steps that can be encoded. + steps_per_second: Shift steps will be interpreted as having a duration of + 1 / steps_per_second. + event_ranges: Other supported event types and their ranges. + """ + self.steps_per_second = steps_per_second + self._shift_range = EventRange(type="shift", min_value=0, max_value=max_shift_steps) + self._event_ranges = [self._shift_range] + event_ranges + # Ensure all event types have unique names. + assert len(self._event_ranges) == len({er.type for er in self._event_ranges}) + + @property + def num_classes(self) -> int: + return sum(er.max_value - er.min_value + 1 for er in self._event_ranges) + + # The next couple methods are simplified special case methods just for shift + # events that are intended to be used from within autograph functions. + + def is_shift_event_index(self, index: int) -> bool: + return (self._shift_range.min_value <= index) and (index <= self._shift_range.max_value) + + @property + def max_shift_steps(self) -> int: + return self._shift_range.max_value + + def encode_event(self, event: Event) -> int: + """Encode an event to an index.""" + offset = 0 + for er in self._event_ranges: + if event.type == er.type: + if not er.min_value <= event.value <= er.max_value: + raise ValueError( + f"Event value {event.value} is not within valid range " + f"[{er.min_value}, {er.max_value}] for type {event.type}" + ) + return offset + event.value - er.min_value + offset += er.max_value - er.min_value + 1 + + raise ValueError(f"Unknown event type: {event.type}") + + def event_type_range(self, event_type: str) -> tuple[int, int]: + """Return [min_id, max_id] for an event type.""" + offset = 0 + for er in self._event_ranges: + if event_type == er.type: + return offset, offset + (er.max_value - er.min_value) + offset += er.max_value - er.min_value + 1 + + raise ValueError(f"Unknown event type: {event_type}") + + def decode_event_index(self, index: int) -> Event: + """Decode an event index to an Event.""" + offset = 0 + for er in self._event_ranges: + if offset <= index <= offset + er.max_value - er.min_value: + return Event(type=er.type, value=er.min_value + index - offset) + offset += er.max_value - er.min_value + 1 + + raise ValueError(f"Unknown event index: {index}") + + +@dataclasses.dataclass +class ProgramGranularity: + # both tokens_map_fn and program_map_fn should be idempotent + tokens_map_fn: Callable[[Sequence[int], Codec], Sequence[int]] + program_map_fn: Callable[[int], int] + + +def drop_programs(tokens, codec: Codec): + """Drops program change events from a token sequence.""" + min_program_id, max_program_id = codec.event_type_range("program") + return tokens[(tokens < min_program_id) | (tokens > max_program_id)] + + +def programs_to_midi_classes(tokens, codec): + """Modifies program events to be the first program in the MIDI class.""" + min_program_id, max_program_id = codec.event_type_range("program") + is_program = (tokens >= min_program_id) & (tokens <= max_program_id) + return np.where(is_program, min_program_id + 8 * ((tokens - min_program_id) // 8), tokens) + + +PROGRAM_GRANULARITIES = { + # "flat" granularity; drop program change tokens and set NoteSequence + # programs to zero + "flat": ProgramGranularity(tokens_map_fn=drop_programs, program_map_fn=lambda program: 0), + # map each program to the first program in its MIDI class + "midi_class": ProgramGranularity( + tokens_map_fn=programs_to_midi_classes, program_map_fn=lambda program: 8 * (program // 8) + ), + # leave programs as is + "full": ProgramGranularity(tokens_map_fn=lambda tokens, codec: tokens, program_map_fn=lambda program: program), +} + + +def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1): + """ + equivalent of tf.signal.frame + """ + signal_length = signal.shape[axis] + if pad_end: + frames_overlap = frame_length - frame_step + rest_samples = np.abs(signal_length - frames_overlap) % np.abs(frame_length - frames_overlap) + pad_size = int(frame_length - rest_samples) + + if pad_size != 0: + pad_axis = [0] * signal.ndim + pad_axis[axis] = pad_size + signal = F.pad(signal, pad_axis, "constant", pad_value) + frames = signal.unfold(axis, frame_length, frame_step) + return frames + + +def program_to_slakh_program(program): + # this is done very hackily, probably should use a custom mapping + for slakh_program in sorted(SLAKH_CLASS_PROGRAMS.values(), reverse=True): + if program >= slakh_program: + return slakh_program + + +def audio_to_frames( + samples, + hop_size: int, + frame_rate: int, +) -> tuple[Sequence[Sequence[int]], torch.Tensor]: + """Convert audio samples to non-overlapping frames and frame times.""" + frame_size = hop_size + samples = np.pad(samples, [0, frame_size - len(samples) % frame_size], mode="constant") + + # Split audio into frames. + frames = frame( + torch.Tensor(samples).unsqueeze(0), + frame_length=frame_size, + frame_step=frame_size, + pad_end=False, # TODO check why its off by 1 here when True + ) + + num_frames = len(samples) // frame_size + + times = np.arange(num_frames) / frame_rate + return frames, times + + +def note_sequence_to_onsets_and_offsets_and_programs( + ns: note_seq.NoteSequence, +) -> tuple[Sequence[float], Sequence[NoteEventData]]: + """Extract onset & offset times and pitches & programs from a NoteSequence. + + The onset & offset times will not necessarily be in sorted order. + + Args: + ns: NoteSequence from which to extract onsets and offsets. + + Returns: + times: A list of note onset and offset times. values: A list of NoteEventData objects where velocity is zero for + note + offsets. + """ + # Sort by program and pitch and put offsets before onsets as a tiebreaker for + # subsequent stable sort. + notes = sorted(ns.notes, key=lambda note: (note.is_drum, note.program, note.pitch)) + times = [note.end_time for note in notes if not note.is_drum] + [note.start_time for note in notes] + values = [ + NoteEventData(pitch=note.pitch, velocity=0, program=note.program, is_drum=False) + for note in notes + if not note.is_drum + ] + [ + NoteEventData(pitch=note.pitch, velocity=note.velocity, program=note.program, is_drum=note.is_drum) + for note in notes + ] + return times, values + + +def num_velocity_bins_from_codec(codec: Codec): + """Get number of velocity bins from event codec.""" + lo, hi = codec.event_type_range("velocity") + return hi - lo + + +# segment an array into segments of length n +def segment(a, n): + return [a[i : i + n] for i in range(0, len(a), n)] + + +def velocity_to_bin(velocity, num_velocity_bins): + if velocity == 0: + return 0 + else: + return math.ceil(num_velocity_bins * velocity / note_seq.MAX_MIDI_VELOCITY) + + +def note_event_data_to_events( + state: NoteEncodingState | None, + value: NoteEventData, + codec: Codec, +) -> Sequence[Event]: + """Convert note event data to a sequence of events.""" + if value.velocity is None: + # onsets only, no program or velocity + return [Event("pitch", value.pitch)] + else: + num_velocity_bins = num_velocity_bins_from_codec(codec) + velocity_bin = velocity_to_bin(value.velocity, num_velocity_bins) + if value.program is None: + # onsets + offsets + velocities only, no programs + if state is not None: + state.active_pitches[(value.pitch, 0)] = velocity_bin + return [Event("velocity", velocity_bin), Event("pitch", value.pitch)] + else: + if value.is_drum: + # drum events use a separate vocabulary + return [Event("velocity", velocity_bin), Event("drum", value.pitch)] + else: + # program + velocity + pitch + if state is not None: + state.active_pitches[(value.pitch, value.program)] = velocity_bin + return [ + Event("program", value.program), + Event("velocity", velocity_bin), + Event("pitch", value.pitch), + ] + + +def note_encoding_state_to_events(state: NoteEncodingState) -> Sequence[Event]: + """Output program and pitch events for active notes plus a final tie event.""" + events = [] + for pitch, program in sorted(state.active_pitches.keys(), key=lambda k: k[::-1]): + if state.active_pitches[(pitch, program)]: + events += [Event("program", program), Event("pitch", pitch)] + events.append(Event("tie", 0)) + return events + + +def encode_and_index_events( + state, event_times, event_values, codec, frame_times, encode_event_fn, encoding_state_to_events_fn=None +): + """Encode a sequence of timed events and index to audio frame times. + + Encodes time shifts as repeated single step shifts for later run length encoding. + + Optionally, also encodes a sequence of "state events", keeping track of the current encoding state at each audio + frame. This can be used e.g. to prepend events representing the current state to a targets segment. + + Args: + state: Initial event encoding state. + event_times: Sequence of event times. + event_values: Sequence of event values. + encode_event_fn: Function that transforms event value into a sequence of one + or more Event objects. + codec: An Codec object that maps Event objects to indices. + frame_times: Time for every audio frame. + encoding_state_to_events_fn: Function that transforms encoding state into a + sequence of one or more Event objects. + + Returns: + events: Encoded events and shifts. event_start_indices: Corresponding start event index for every audio frame. + Note: one event can correspond to multiple audio indices due to sampling rate differences. This makes + splitting sequences tricky because the same event can appear at the end of one sequence and the beginning of + another. + event_end_indices: Corresponding end event index for every audio frame. Used + to ensure when slicing that one chunk ends where the next begins. Should always be true that + event_end_indices[i] = event_start_indices[i + 1]. + state_events: Encoded "state" events representing the encoding state before + each event. + state_event_indices: Corresponding state event index for every audio frame. + """ + indices = np.argsort(event_times, kind="stable") + event_steps = [round(event_times[i] * codec.steps_per_second) for i in indices] + event_values = [event_values[i] for i in indices] + + events = [] + state_events = [] + event_start_indices = [] + state_event_indices = [] + + cur_step = 0 + cur_event_idx = 0 + cur_state_event_idx = 0 + + def fill_event_start_indices_to_cur_step(): + while ( + len(event_start_indices) < len(frame_times) + and frame_times[len(event_start_indices)] < cur_step / codec.steps_per_second + ): + event_start_indices.append(cur_event_idx) + state_event_indices.append(cur_state_event_idx) + + for event_step, event_value in zip(event_steps, event_values): + while event_step > cur_step: + events.append(codec.encode_event(Event(type="shift", value=1))) + cur_step += 1 + fill_event_start_indices_to_cur_step() + cur_event_idx = len(events) + cur_state_event_idx = len(state_events) + if encoding_state_to_events_fn: + # Dump state to state events *before* processing the next event, because + # we want to capture the state prior to the occurrence of the event. + for e in encoding_state_to_events_fn(state): + state_events.append(codec.encode_event(e)) + + for e in encode_event_fn(state, event_value, codec): + events.append(codec.encode_event(e)) + + # After the last event, continue filling out the event_start_indices array. + # The inequality is not strict because if our current step lines up exactly + # with (the start of) an audio frame, we need to add an additional shift event + # to "cover" that frame. + while cur_step / codec.steps_per_second <= frame_times[-1]: + events.append(codec.encode_event(Event(type="shift", value=1))) + cur_step += 1 + fill_event_start_indices_to_cur_step() + cur_event_idx = len(events) + + # Now fill in event_end_indices. We need this extra array to make sure that + # when we slice events, each slice ends exactly where the subsequent slice + # begins. + event_end_indices = event_start_indices[1:] + [len(events)] + + events = np.array(events).astype(np.int32) + state_events = np.array(state_events).astype(np.int32) + event_start_indices = segment(np.array(event_start_indices).astype(np.int32), TARGET_FEATURE_LENGTH) + event_end_indices = segment(np.array(event_end_indices).astype(np.int32), TARGET_FEATURE_LENGTH) + state_event_indices = segment(np.array(state_event_indices).astype(np.int32), TARGET_FEATURE_LENGTH) + + outputs = [] + for start_indices, end_indices, event_indices in zip(event_start_indices, event_end_indices, state_event_indices): + outputs.append( + { + "inputs": events, + "event_start_indices": start_indices, + "event_end_indices": end_indices, + "state_events": state_events, + "state_event_indices": event_indices, + } + ) + + return outputs + + +def extract_sequence_with_indices(features, state_events_end_token=None, feature_key="inputs"): + """Extract target sequence corresponding to audio token segment.""" + features = features.copy() + start_idx = features["event_start_indices"][0] + end_idx = features["event_end_indices"][-1] + + features[feature_key] = features[feature_key][start_idx:end_idx] + + if state_events_end_token is not None: + # Extract the state events corresponding to the audio start token, and + # prepend them to the targets array. + state_event_start_idx = features["state_event_indices"][0] + state_event_end_idx = state_event_start_idx + 1 + while features["state_events"][state_event_end_idx - 1] != state_events_end_token: + state_event_end_idx += 1 + features[feature_key] = np.concatenate( + [ + features["state_events"][state_event_start_idx:state_event_end_idx], + features[feature_key], + ], + axis=0, + ) + + return features + + +def map_midi_programs( + feature, codec: Codec, granularity_type: str = "full", feature_key: str = "inputs" +) -> Mapping[str, Any]: + """Apply MIDI program map to token sequences.""" + granularity = PROGRAM_GRANULARITIES[granularity_type] + + feature[feature_key] = granularity.tokens_map_fn(feature[feature_key], codec) + return feature + + +def run_length_encode_shifts_fn( + features, + codec: Codec, + feature_key: str = "inputs", + state_change_event_types: Sequence[str] = (), +) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]: + """Return a function that run-length encodes shifts for a given codec. + + Args: + codec: The Codec to use for shift events. + feature_key: The feature key for which to run-length encode shifts. + state_change_event_types: A list of event types that represent state + changes; tokens corresponding to these event types will be interpreted as state changes and redundant ones + will be removed. + + Returns: + A preprocessing function that run-length encodes single-step shifts. + """ + state_change_event_ranges = [codec.event_type_range(event_type) for event_type in state_change_event_types] + + def run_length_encode_shifts(features: MutableMapping[str, Any]) -> Mapping[str, Any]: + """Combine leading/interior shifts, trim trailing shifts. + + Args: + features: Dict of features to process. + + Returns: + A dict of features. + """ + events = features[feature_key] + + shift_steps = 0 + total_shift_steps = 0 + output = np.array([], dtype=np.int32) + + current_state = np.zeros(len(state_change_event_ranges), dtype=np.int32) + + for event in events: + if codec.is_shift_event_index(event): + shift_steps += 1 + total_shift_steps += 1 + + else: + # If this event is a state change and has the same value as the current + # state, we can skip it entirely. + is_redundant = False + for i, (min_index, max_index) in enumerate(state_change_event_ranges): + if (min_index <= event) and (event <= max_index): + if current_state[i] == event: + is_redundant = True + current_state[i] = event + if is_redundant: + continue + + # Once we've reached a non-shift event, RLE all previous shift events + # before outputting the non-shift event. + if shift_steps > 0: + shift_steps = total_shift_steps + while shift_steps > 0: + output_steps = np.minimum(codec.max_shift_steps, shift_steps) + output = np.concatenate([output, [output_steps]], axis=0) + shift_steps -= output_steps + output = np.concatenate([output, [event]], axis=0) + + features[feature_key] = output + return features + + return run_length_encode_shifts(features) + + +def note_representation_processor_chain(features, codec: Codec, note_representation_config: NoteRepresentationConfig): + tie_token = codec.encode_event(Event("tie", 0)) + state_events_end_token = tie_token if note_representation_config.include_ties else None + + features = extract_sequence_with_indices( + features, state_events_end_token=state_events_end_token, feature_key="inputs" + ) + + features = map_midi_programs(features, codec) + + features = run_length_encode_shifts_fn(features, codec, state_change_event_types=["velocity", "program"]) + + return features + + +class MidiProcessor: + def __init__(self): + self.codec = Codec( + max_shift_steps=DEFAULT_MAX_SHIFT_SECONDS * DEFAULT_STEPS_PER_SECOND, + steps_per_second=DEFAULT_STEPS_PER_SECOND, + event_ranges=[ + EventRange("pitch", note_seq.MIN_MIDI_PITCH, note_seq.MAX_MIDI_PITCH), + EventRange("velocity", 0, DEFAULT_NUM_VELOCITY_BINS), + EventRange("tie", 0, 0), + EventRange("program", note_seq.MIN_MIDI_PROGRAM, note_seq.MAX_MIDI_PROGRAM), + EventRange("drum", note_seq.MIN_MIDI_PITCH, note_seq.MAX_MIDI_PITCH), + ], + ) + self.tokenizer = Tokenizer(self.codec.num_classes) + self.note_representation_config = NoteRepresentationConfig(onsets_only=False, include_ties=True) + + def __call__(self, midi: bytes | os.PathLike | str): + if not isinstance(midi, bytes): + with open(midi, "rb") as f: + midi = f.read() + + ns = note_seq.midi_to_note_sequence(midi) + ns_sus = note_seq.apply_sustain_control_changes(ns) + + for note in ns_sus.notes: + if not note.is_drum: + note.program = program_to_slakh_program(note.program) + + samples = np.zeros(int(ns_sus.total_time * SAMPLE_RATE)) + + _, frame_times = audio_to_frames(samples, HOP_SIZE, FRAME_RATE) + times, values = note_sequence_to_onsets_and_offsets_and_programs(ns_sus) + + events = encode_and_index_events( + state=NoteEncodingState(), + event_times=times, + event_values=values, + frame_times=frame_times, + codec=self.codec, + encode_event_fn=note_event_data_to_events, + encoding_state_to_events_fn=note_encoding_state_to_events, + ) + + events = [ + note_representation_processor_chain(event, self.codec, self.note_representation_config) for event in events + ] + input_tokens = [self.tokenizer.encode(event["inputs"]) for event in events] + + return input_tokens diff --git a/diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py b/diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..25ad4a4ccfd28b2423ee9771130a24c941c06649 --- /dev/null +++ b/diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py @@ -0,0 +1,86 @@ +# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from transformers.modeling_utils import ModuleUtilsMixin +from transformers.models.t5.modeling_t5 import T5Block, T5Config, T5LayerNorm + +from ....configuration_utils import ConfigMixin, register_to_config +from ....models import ModelMixin + + +class SpectrogramNotesEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): + @register_to_config + def __init__( + self, + max_length: int, + vocab_size: int, + d_model: int, + dropout_rate: float, + num_layers: int, + num_heads: int, + d_kv: int, + d_ff: int, + feed_forward_proj: str, + is_decoder: bool = False, + ): + super().__init__() + + self.token_embedder = nn.Embedding(vocab_size, d_model) + + self.position_encoding = nn.Embedding(max_length, d_model) + self.position_encoding.weight.requires_grad = False + + self.dropout_pre = nn.Dropout(p=dropout_rate) + + t5config = T5Config( + vocab_size=vocab_size, + d_model=d_model, + num_heads=num_heads, + d_kv=d_kv, + d_ff=d_ff, + dropout_rate=dropout_rate, + feed_forward_proj=feed_forward_proj, + is_decoder=is_decoder, + is_encoder_decoder=False, + ) + + self.encoders = nn.ModuleList() + for lyr_num in range(num_layers): + lyr = T5Block(t5config) + self.encoders.append(lyr) + + self.layer_norm = T5LayerNorm(d_model) + self.dropout_post = nn.Dropout(p=dropout_rate) + + def forward(self, encoder_input_tokens, encoder_inputs_mask): + x = self.token_embedder(encoder_input_tokens) + + seq_length = encoder_input_tokens.shape[1] + inputs_positions = torch.arange(seq_length, device=encoder_input_tokens.device) + x += self.position_encoding(inputs_positions) + + x = self.dropout_pre(x) + + # inverted the attention mask + input_shape = encoder_input_tokens.size() + extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, input_shape) + + for lyr in self.encoders: + x = lyr(x, extended_attention_mask)[0] + x = self.layer_norm(x) + + return self.dropout_post(x), encoder_inputs_mask diff --git a/diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py b/diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..269e7405d10d2be9f44d481b05725b2918727ea9 --- /dev/null +++ b/diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py @@ -0,0 +1,269 @@ +# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any, Callable + +import numpy as np +import torch + +from ....models import T5FilmDecoder +from ....schedulers import DDPMScheduler +from ....utils import is_onnx_available, logging +from ....utils.torch_utils import randn_tensor + + +if is_onnx_available(): + from ...onnx_utils import OnnxRuntimeModel + +from ...pipeline_utils import AudioPipelineOutput, DiffusionPipeline +from .continuous_encoder import SpectrogramContEncoder +from .notes_encoder import SpectrogramNotesEncoder + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +TARGET_FEATURE_LENGTH = 256 + + +class SpectrogramDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for unconditional audio generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + notes_encoder ([`SpectrogramNotesEncoder`]): + continuous_encoder ([`SpectrogramContEncoder`]): + decoder ([`T5FilmDecoder`]): + A [`T5FilmDecoder`] to denoise the encoded audio latents. + scheduler ([`DDPMScheduler`]): + A scheduler to be used in combination with `decoder` to denoise the encoded audio latents. + melgan ([`OnnxRuntimeModel`]): + """ + + _optional_components = ["melgan"] + + def __init__( + self, + notes_encoder: SpectrogramNotesEncoder, + continuous_encoder: SpectrogramContEncoder, + decoder: T5FilmDecoder, + scheduler: DDPMScheduler, + melgan: OnnxRuntimeModel if is_onnx_available() else Any, + ) -> None: + super().__init__() + + # From MELGAN + self.min_value = math.log(1e-5) # Matches MelGAN training. + self.max_value = 4.0 # Largest value for most examples + self.n_dims = 128 + + self.register_modules( + notes_encoder=notes_encoder, + continuous_encoder=continuous_encoder, + decoder=decoder, + scheduler=scheduler, + melgan=melgan, + ) + + def scale_features(self, features, output_range=(-1.0, 1.0), clip=False): + """Linearly scale features to network outputs range.""" + min_out, max_out = output_range + if clip: + features = torch.clip(features, self.min_value, self.max_value) + # Scale to [0, 1]. + zero_one = (features - self.min_value) / (self.max_value - self.min_value) + # Scale to [min_out, max_out]. + return zero_one * (max_out - min_out) + min_out + + def scale_to_features(self, outputs, input_range=(-1.0, 1.0), clip=False): + """Invert by linearly scaling network outputs to features range.""" + min_out, max_out = input_range + outputs = torch.clip(outputs, min_out, max_out) if clip else outputs + # Scale to [0, 1]. + zero_one = (outputs - min_out) / (max_out - min_out) + # Scale to [self.min_value, self.max_value]. + return zero_one * (self.max_value - self.min_value) + self.min_value + + def encode(self, input_tokens, continuous_inputs, continuous_mask): + tokens_mask = input_tokens > 0 + tokens_encoded, tokens_mask = self.notes_encoder( + encoder_input_tokens=input_tokens, encoder_inputs_mask=tokens_mask + ) + + continuous_encoded, continuous_mask = self.continuous_encoder( + encoder_inputs=continuous_inputs, encoder_inputs_mask=continuous_mask + ) + + return [(tokens_encoded, tokens_mask), (continuous_encoded, continuous_mask)] + + def decode(self, encodings_and_masks, input_tokens, noise_time): + timesteps = noise_time + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=input_tokens.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(input_tokens.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps * torch.ones(input_tokens.shape[0], dtype=timesteps.dtype, device=timesteps.device) + + logits = self.decoder( + encodings_and_masks=encodings_and_masks, decoder_input_tokens=input_tokens, decoder_noise_time=timesteps + ) + return logits + + @torch.no_grad() + def __call__( + self, + input_tokens: list[list[int]], + generator: torch.Generator | None = None, + num_inference_steps: int = 100, + return_dict: bool = True, + output_type: str = "np", + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + ) -> AudioPipelineOutput | tuple: + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + r""" + The call function to the pipeline for generation. + + Args: + input_tokens (`list[list[int]]`): + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality audio at the + expense of slower inference. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated audio. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Example: + + ```py + >>> from diffusers import SpectrogramDiffusionPipeline, MidiProcessor + + >>> pipe = SpectrogramDiffusionPipeline.from_pretrained("google/music-spectrogram-diffusion") + >>> pipe = pipe.to("cuda") + >>> processor = MidiProcessor() + + >>> # Download MIDI from: wget http://www.piano-midi.de/midis/beethoven/beethoven_hammerklavier_2.mid + >>> output = pipe(processor("beethoven_hammerklavier_2.mid")) + + >>> audio = output.audios[0] + ``` + + Returns: + [`pipelines.AudioPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`pipelines.AudioPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated audio. + """ + + pred_mel = np.zeros([1, TARGET_FEATURE_LENGTH, self.n_dims], dtype=np.float32) + full_pred_mel = np.zeros([1, 0, self.n_dims], np.float32) + ones = torch.ones((1, TARGET_FEATURE_LENGTH), dtype=bool, device=self.device) + + for i, encoder_input_tokens in enumerate(input_tokens): + if i == 0: + encoder_continuous_inputs = torch.from_numpy(pred_mel[:1].copy()).to( + device=self.device, dtype=self.decoder.dtype + ) + # The first chunk has no previous context. + encoder_continuous_mask = torch.zeros((1, TARGET_FEATURE_LENGTH), dtype=bool, device=self.device) + else: + # The full song pipeline does not feed in a context feature, so the mask + # will be all 0s after the feature converter. Because we know we're + # feeding in a full context chunk from the previous prediction, set it + # to all 1s. + encoder_continuous_mask = ones + + encoder_continuous_inputs = self.scale_features( + encoder_continuous_inputs, output_range=[-1.0, 1.0], clip=True + ) + + encodings_and_masks = self.encode( + input_tokens=torch.IntTensor([encoder_input_tokens]).to(device=self.device), + continuous_inputs=encoder_continuous_inputs, + continuous_mask=encoder_continuous_mask, + ) + + # Sample encoder_continuous_inputs shaped gaussian noise to begin loop + x = randn_tensor( + shape=encoder_continuous_inputs.shape, + generator=generator, + device=self.device, + dtype=self.decoder.dtype, + ) + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + + # Denoising diffusion loop + for j, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + output = self.decode( + encodings_and_masks=encodings_and_masks, + input_tokens=x, + noise_time=t / self.scheduler.config.num_train_timesteps, # rescale to [0, 1) + ) + + # Compute previous output: x_t -> x_t-1 + x = self.scheduler.step(output, t, x, generator=generator).prev_sample + + mel = self.scale_to_features(x, input_range=[-1.0, 1.0]) + encoder_continuous_inputs = mel[:1] + pred_mel = mel.cpu().float().numpy() + + full_pred_mel = np.concatenate([full_pred_mel, pred_mel[:1]], axis=1) + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, full_pred_mel) + + logger.info("Generated segment", i) + + if output_type == "np" and not is_onnx_available(): + raise ValueError( + "Cannot return output in 'np' format if ONNX is not available. Make sure to have ONNX installed or set 'output_type' to 'mel'." + ) + elif output_type == "np" and self.melgan is None: + raise ValueError( + "Cannot return output in 'np' format if melgan component is not defined. Make sure to define `self.melgan` or set 'output_type' to 'mel'." + ) + + if output_type == "np": + output = self.melgan(input_features=full_pred_mel.astype(np.float32)) + else: + output = full_pred_mel + + if not return_dict: + return (output,) + + return AudioPipelineOutput(audios=output) diff --git a/diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py b/diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..36cf1a33ce6ada8e718aabadb9a706737aee30bd --- /dev/null +++ b/diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py @@ -0,0 +1,55 @@ +from typing import TYPE_CHECKING + +from ....utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ....utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_cycle_diffusion"] = ["CycleDiffusionPipeline"] + _import_structure["pipeline_stable_diffusion_inpaint_legacy"] = ["StableDiffusionInpaintPipelineLegacy"] + _import_structure["pipeline_stable_diffusion_model_editing"] = ["StableDiffusionModelEditingPipeline"] + + _import_structure["pipeline_stable_diffusion_paradigms"] = ["StableDiffusionParadigmsPipeline"] + _import_structure["pipeline_stable_diffusion_pix2pix_zero"] = ["StableDiffusionPix2PixZeroPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ....utils.dummy_torch_and_transformers_objects import * + + else: + from .pipeline_cycle_diffusion import CycleDiffusionPipeline + from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy + from .pipeline_stable_diffusion_model_editing import StableDiffusionModelEditingPipeline + from .pipeline_stable_diffusion_paradigms import StableDiffusionParadigmsPipeline + from .pipeline_stable_diffusion_pix2pix_zero import StableDiffusionPix2PixZeroPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py b/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..dae5e600d773a388b5fa5dc79dfde78005429554 --- /dev/null +++ b/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py @@ -0,0 +1,953 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ....configuration_utils import FrozenDict +from ....image_processor import PipelineImageInput, VaeImageProcessor +from ....loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ....models import AutoencoderKL, UNet2DConditionModel +from ....models.lora import adjust_lora_scale_text_encoder +from ....schedulers import DDIMScheduler +from ....utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline +from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta): + # 1. get previous step value (=t-1) + prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps + + if prev_timestep <= 0: + return clean_latents + + # 2. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = ( + scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod + ) + + variance = scheduler._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + # direction pointing to x_t + e_t = (latents - alpha_prod_t ** (0.5) * clean_latents) / (1 - alpha_prod_t) ** (0.5) + dir_xt = (1.0 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * e_t + noise = std_dev_t * randn_tensor( + clean_latents.shape, dtype=clean_latents.dtype, device=clean_latents.device, generator=generator + ) + prev_latents = alpha_prod_t_prev ** (0.5) * clean_latents + dir_xt + noise + + return prev_latents + + +def compute_noise(scheduler, prev_latents, latents, timestep, noise_pred, eta): + # 1. get previous step value (=t-1) + prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = ( + scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod + ) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://huggingface.co/papers/2010.02502 + pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) + + # 4. Clip "predicted x_0" + if scheduler.config.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = scheduler._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://huggingface.co/papers/2010.02502 + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * noise_pred + + noise = (prev_latents - (alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction)) / ( + variance ** (0.5) * eta + ) + return noise + + +class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin): + r""" + Pipeline for text-guided image to image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can only be an + instance of [`DDIMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + image = image.to(device=device, dtype=dtype) + + batch_size = image.shape[0] + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) + + # add noise to latents using the timestep + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + clean_latents = init_latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents, clean_latents + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str], + source_prompt: str | list[str], + image: PipelineImageInput = None, + strength: float = 0.8, + num_inference_steps: int | None = 50, + guidance_scale: float | None = 7.5, + source_guidance_scale: float | None = 1, + num_images_per_prompt: int | None = 1, + eta: float | None = 0.1, + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + image (`torch.Tensor` `np.ndarray`, `PIL.Image.Image`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image` or tensor representing an image batch to be used as the starting point. Can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + source_guidance_scale (`float`, *optional*, defaults to 1): + Guidance scale for the source prompt. This is useful to control the amount of influence the source + prompt has for encoding. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Example: + + ```py + import requests + import torch + from PIL import Image + from io import BytesIO + + from diffusers import CycleDiffusionPipeline, DDIMScheduler + + # load the pipeline + # make sure you're logged in with `hf auth login` + model_id_or_path = "CompVis/stable-diffusion-v1-4" + scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler") + pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda") + + # let's download an initial image + url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/An%20astronaut%20riding%20a%20horse.png" + response = requests.get(url) + init_image = Image.open(BytesIO(response.content)).convert("RGB") + init_image = init_image.resize((512, 512)) + init_image.save("horse.png") + + # let's specify a prompt + source_prompt = "An astronaut riding a horse" + prompt = "An astronaut riding an elephant" + + # call the pipeline + image = pipe( + prompt=prompt, + source_prompt=source_prompt, + image=init_image, + num_inference_steps=100, + eta=0.1, + strength=0.8, + guidance_scale=2, + source_guidance_scale=1, + ).images[0] + + image.save("horse_to_elephant.png") + + # let's try another example + # See more samples at the original repo: https://github.com/ChenWu98/cycle-diffusion + url = ( + "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/A%20black%20colored%20car.png" + ) + response = requests.get(url) + init_image = Image.open(BytesIO(response.content)).convert("RGB") + init_image = init_image.resize((512, 512)) + init_image.save("black.png") + + source_prompt = "A black colored car" + prompt = "A blue colored car" + + # call the pipeline + torch.manual_seed(0) + image = pipe( + prompt=prompt, + source_prompt=source_prompt, + image=init_image, + num_inference_steps=100, + eta=0.1, + strength=0.85, + guidance_scale=3, + source_guidance_scale=1, + ).images[0] + + image.save("black_to_blue.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 1. Check inputs + self.check_inputs(prompt, strength, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds_tuple = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + source_prompt_embeds_tuple = self.encode_prompt( + source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None, clip_skip=clip_skip + ) + if prompt_embeds_tuple[1] is not None: + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + else: + prompt_embeds = prompt_embeds_tuple[0] + if source_prompt_embeds_tuple[1] is not None: + source_prompt_embeds = torch.cat([source_prompt_embeds_tuple[1], source_prompt_embeds_tuple[0]]) + else: + source_prompt_embeds = source_prompt_embeds_tuple[0] + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents, clean_latents = self.prepare_latents( + image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator + ) + source_latents = latents + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + generator = extra_step_kwargs.pop("generator", None) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + source_latent_model_input = ( + torch.cat([source_latents] * 2) if do_classifier_free_guidance else source_latents + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t) + + # predict the noise residual + if do_classifier_free_guidance: + concat_latent_model_input = torch.stack( + [ + source_latent_model_input[0], + latent_model_input[0], + source_latent_model_input[1], + latent_model_input[1], + ], + dim=0, + ) + concat_prompt_embeds = torch.stack( + [ + source_prompt_embeds[0], + prompt_embeds[0], + source_prompt_embeds[1], + prompt_embeds[1], + ], + dim=0, + ) + else: + concat_latent_model_input = torch.cat( + [ + source_latent_model_input, + latent_model_input, + ], + dim=0, + ) + concat_prompt_embeds = torch.cat( + [ + source_prompt_embeds, + prompt_embeds, + ], + dim=0, + ) + + concat_noise_pred = self.unet( + concat_latent_model_input, + t, + cross_attention_kwargs=cross_attention_kwargs, + encoder_hidden_states=concat_prompt_embeds, + ).sample + + # perform guidance + if do_classifier_free_guidance: + ( + source_noise_pred_uncond, + noise_pred_uncond, + source_noise_pred_text, + noise_pred_text, + ) = concat_noise_pred.chunk(4, dim=0) + + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + source_noise_pred = source_noise_pred_uncond + source_guidance_scale * ( + source_noise_pred_text - source_noise_pred_uncond + ) + + else: + (source_noise_pred, noise_pred) = concat_noise_pred.chunk(2, dim=0) + + # Sample source_latents from the posterior distribution. + prev_source_latents = posterior_sample( + self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs + ) + # Compute noise. + noise = compute_noise( + self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs + ) + source_latents = prev_source_latents + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs + ).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # 9. Post-processing + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py b/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py new file mode 100644 index 0000000000000000000000000000000000000000..f526dc419cea63d5e60fadae61a10c07f266dc1a --- /dev/null +++ b/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py @@ -0,0 +1,543 @@ +import inspect +from typing import Callable + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTokenizer + +from ....configuration_utils import FrozenDict +from ....schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ....utils import deprecate, logging +from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from ...pipeline_utils import DiffusionPipeline +from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def preprocess(image): + w, h = image.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask, scale_factor=8): + mask = mask.convert("L") + w, h = mask.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + return mask + + +class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. This is a *legacy feature* for Onnx pipelines to + provide compatibility with StableDiffusionInpaintPipelineLegacy and may be removed in the future. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + _optional_components = ["safety_checker", "feature_extractor"] + _is_onnx = True + + vae_encoder: OnnxRuntimeModel + vae_decoder: OnnxRuntimeModel + text_encoder: OnnxRuntimeModel + tokenizer: CLIPTokenizer + unet: OnnxRuntimeModel + scheduler: DDIMScheduler | PNDMScheduler | LMSDiscreteScheduler + safety_checker: OnnxRuntimeModel + feature_extractor: CLIPImageProcessor + + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: DDIMScheduler | PNDMScheduler | LMSDiscreteScheduler, + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt: str | list[str], + num_images_per_prompt: int | None, + do_classifier_free_guidance: bool, + negative_prompt: str | None, + prompt_embeds: np.ndarray | None = None, + negative_prompt_embeds: np.ndarray | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + if do_classifier_free_guidance: + negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def check_inputs( + self, + prompt, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def __call__( + self, + prompt: str | list[str], + image: np.ndarray | PIL.Image.Image = None, + mask_image: np.ndarray | PIL.Image.Image = None, + strength: float = 0.8, + num_inference_steps: int | None = 50, + guidance_scale: float | None = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float | None = 0.0, + generator: np.random.RandomState | None = None, + prompt_embeds: np.ndarray | None = None, + negative_prompt_embeds: np.ndarray | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, np.ndarray], None] | None = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + image (`nd.ndarray` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose masked region will be inpainted. + mask_image (`nd.ndarray` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.uu + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (?) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + A np.random.RandomState to make generation deterministic. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # check inputs. Raise error if not correct + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if generator is None: + generator = np.random + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + if isinstance(image, PIL.Image.Image): + image = preprocess(image) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + latents_dtype = prompt_embeds.dtype + image = image.astype(latents_dtype) + + # encode the init image into latents and scale the latents + init_latents = self.vae_encoder(sample=image)[0] + init_latents = 0.18215 * init_latents + + # Expand init_latents for batch_size and num_images_per_prompt + init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0) + init_latents_orig = init_latents + + # preprocess mask + if not isinstance(mask_image, np.ndarray): + mask_image = preprocess_mask(mask_image, 8) + mask_image = mask_image.astype(latents_dtype) + mask = np.concatenate([mask_image] * num_images_per_prompt, axis=0) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError("The mask and image should be the same size!") + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps.numpy()[-init_timestep] + timesteps = np.array([timesteps] * batch_size * num_images_per_prompt) + + # add noise to latents using the timesteps + noise = generator.randn(*init_latents.shape).astype(latents_dtype) + init_latents = self.scheduler.add_noise( + torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps) + ) + init_latents = init_latents.numpy() + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (?) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to ? in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + latents = init_latents + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].numpy() + timestep_dtype = next( + (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[ + 0 + ] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ).prev_sample + + latents = latents.numpy() + + init_latents_proper = self.scheduler.add_noise( + torch.from_numpy(init_latents_orig), torch.from_numpy(noise), torch.from_numpy(np.array([t])) + ) + + init_latents_proper = init_latents_proper.numpy() + + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + latents = 1 / 0.18215 * latents + # image = self.vae_decoder(latent_sample=latents)[0] + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="np" + ).pixel_values.astype(image.dtype) + # There will throw an error if use safety_checker batchsize>1 + images, has_nsfw_concept = [], [] + for i in range(image.shape[0]): + image_i, has_nsfw_concept_i = self.safety_checker( + clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] + ) + images.append(image_i) + has_nsfw_concept.append(has_nsfw_concept_i[0]) + image = np.concatenate(images) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py b/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py new file mode 100644 index 0000000000000000000000000000000000000000..650695b604c1014f1b02d78ace239e82029014ee --- /dev/null +++ b/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py @@ -0,0 +1,789 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ....configuration_utils import FrozenDict +from ....image_processor import VaeImageProcessor +from ....loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ....models import AutoencoderKL, UNet2DConditionModel +from ....models.lora import adjust_lora_scale_text_encoder +from ....schedulers import KarrasDiffusionSchedulers +from ....utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline +from ...stable_diffusion import StableDiffusionPipelineOutput +from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) + + +def preprocess_image(image, batch_size): + w, h = image.size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = np.array(image).astype(np.float32) / 255.0 + image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask, batch_size, scale_factor=8): + if not isinstance(mask, torch.Tensor): + mask = mask.convert("L") + w, h = mask.size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = np.vstack([mask[None]] * batch_size) + mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + + else: + valid_mask_channel_sizes = [1, 3] + # if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W) + if mask.shape[3] in valid_mask_channel_sizes: + mask = mask.permute(0, 3, 1, 2) + elif mask.shape[1] not in valid_mask_channel_sizes: + raise ValueError( + f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension," + f" but received mask of shape {tuple(mask.shape)}" + ) + # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape + mask = mask.mean(dim=1, keepdim=True) + h, w = mask.shape[-2:] + h, w = (x - x % 8 for x in (h, w)) # resize to integer multiple of 8 + mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor)) + return mask + + +class StableDiffusionInpaintPipelineLegacy( + DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] + + as well as the following saving methods: + - *LoRA*: [`loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["feature_extractor"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + deprecation_message = ( + f"The class {self.__class__} is deprecated and will be removed in v1.0.0. You can achieve exactly the same functionality" + "by loading your model into `StableDiffusionInpaintPipeline` instead. See https://github.com/huggingface/diffusers/pull/3533" + "for more information." + ) + deprecate("legacy is outdated", "1.0.0", deprecation_message, standard_warn=False) + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, image, timestep, num_images_per_prompt, dtype, device, generator): + image = image.to(device=device, dtype=dtype) + init_latent_dist = self.vae.encode(image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = self.vae.config.scaling_factor * init_latents + + # Expand init_latents for batch_size and num_images_per_prompt + init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) + init_latents_orig = init_latents + + # add noise to latents using the timesteps + noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype) + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + return latents, init_latents_orig, noise + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str] = None, + image: torch.Tensor | PIL.Image.Image = None, + mask_image: torch.Tensor | PIL.Image.Image = None, + strength: float = 0.8, + num_inference_steps: int | None = 50, + guidance_scale: float | None = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + add_predicted_noise: bool | None = False, + eta: float | None = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.Tensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose masked region will be inpainted. + mask_image (`torch.Tensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If mask is a tensor, the + expected shape should be either `(B, H, W, C)` or `(B, C, H, W)`, where C is 1 or 3. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` + is 1, the denoising process will be run on the masked area for the full number of iterations specified + in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more noise to + that region the larger the `strength`. If `strength` is 0, no inpainting will occur. + num_inference_steps (`int`, *optional*, defaults to 50): + The reference number of denoising steps. More denoising steps usually lead to a higher quality image at + the expense of slower inference. This parameter will be modulated by `strength`, as explained above. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + add_predicted_noise (`bool`, *optional*, defaults to True): + Use predicted noise instead of random noise when constructing noisy versions of the original image in + the reverse diffusion process + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 1. Check inputs + self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Preprocess image and mask + if not isinstance(image, torch.Tensor): + image = preprocess_image(image, batch_size) + + mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + # encode the init image into latents and scale the latents + latents, init_latents_orig, noise = self.prepare_latents( + image, latent_timestep, num_images_per_prompt, prompt_embeds.dtype, device, generator + ) + + # 7. Prepare mask latent + mask = mask_image.to(device=device, dtype=latents.dtype) + mask = torch.cat([mask] * num_images_per_prompt) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + # masking + if add_predicted_noise: + init_latents_proper = self.scheduler.add_noise( + init_latents_orig, noise_pred_uncond, torch.tensor([t]) + ) + else: + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) + + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # use original latents corresponding to unmasked portions of the image + latents = (init_latents_orig * mask) + (latents * (1 - mask)) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py b/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py new file mode 100644 index 0000000000000000000000000000000000000000..851820c00aedf6089aa11c1c59d4d663fa41435b --- /dev/null +++ b/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py @@ -0,0 +1,832 @@ +# Copyright 2025 TIME Authors and The HuggingFace Team. All rights reserved." +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +from typing import Any, Callable + +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ....image_processor import VaeImageProcessor +from ....loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ....models import AutoencoderKL, UNet2DConditionModel +from ....models.lora import adjust_lora_scale_text_encoder +from ....schedulers import PNDMScheduler +from ....schedulers.scheduling_utils import SchedulerMixin +from ....utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +AUGS_CONST = ["A photo of ", "An image of ", "A picture of "] + + +class StableDiffusionModelEditingPipeline( + DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin +): + r""" + Pipeline for text-to-image model editing. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + with_to_k ([`bool`]): + Whether to edit the key projection matrices along with the value projection matrices. + with_augs ([`list`]): + Textual augmentations to apply while editing the text-to-image model. Set to `[]` for no augmentations. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: SchedulerMixin, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + with_to_k: bool = True, + with_augs: list = AUGS_CONST, + ): + super().__init__() + + if isinstance(scheduler, PNDMScheduler): + logger.error("PNDMScheduler for this pipeline is currently not supported.") + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + self.with_to_k = with_to_k + self.with_augs = with_augs + + # get cross-attention layers + ca_layers = [] + + def append_ca(net_): + if net_.__class__.__name__ == "CrossAttention": + ca_layers.append(net_) + elif hasattr(net_, "children"): + for net__ in net_.children(): + append_ca(net__) + + # recursively find all cross-attention layers in unet + for net in self.unet.named_children(): + if "down" in net[0]: + append_ca(net[1]) + elif "up" in net[0]: + append_ca(net[1]) + elif "mid" in net[0]: + append_ca(net[1]) + + # get projection matrices + self.ca_clip_layers = [l for l in ca_layers if l.to_v.in_features == 768] + self.projection_matrices = [l.to_v for l in self.ca_clip_layers] + self.og_matrices = [copy.deepcopy(l.to_v) for l in self.ca_clip_layers] + if self.with_to_k: + self.projection_matrices = self.projection_matrices + [l.to_k for l in self.ca_clip_layers] + self.og_matrices = self.og_matrices + [copy.deepcopy(l.to_k) for l in self.ca_clip_layers] + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def edit_model( + self, + source_prompt: str, + destination_prompt: str, + lamb: float = 0.1, + restart_params: bool = True, + ): + r""" + Apply model editing via closed-form solution (see Eq. 5 in the TIME + [paper](https://huggingface.co/papers/2303.08084)). + + Args: + source_prompt (`str`): + The source prompt containing the concept to be edited. + destination_prompt (`str`): + The destination prompt. Must contain all words from `source_prompt` with additional ones to specify the + target edit. + lamb (`float`, *optional*, defaults to 0.1): + The lambda parameter specifying the regularization intensity. Smaller values increase the editing + power. + restart_params (`bool`, *optional*, defaults to True): + Restart the model parameters to their pre-trained version before editing. This is done to avoid edit + compounding. When it is `False`, edits accumulate. + """ + + # restart LDM parameters + if restart_params: + num_ca_clip_layers = len(self.ca_clip_layers) + for idx_, l in enumerate(self.ca_clip_layers): + l.to_v = copy.deepcopy(self.og_matrices[idx_]) + self.projection_matrices[idx_] = l.to_v + if self.with_to_k: + l.to_k = copy.deepcopy(self.og_matrices[num_ca_clip_layers + idx_]) + self.projection_matrices[num_ca_clip_layers + idx_] = l.to_k + + # set up sentences + old_texts = [source_prompt] + new_texts = [destination_prompt] + # add augmentations + base = old_texts[0] if old_texts[0][0:1] != "A" else "a" + old_texts[0][1:] + for aug in self.with_augs: + old_texts.append(aug + base) + base = new_texts[0] if new_texts[0][0:1] != "A" else "a" + new_texts[0][1:] + for aug in self.with_augs: + new_texts.append(aug + base) + + # prepare input k* and v* + old_embs, new_embs = [], [] + for old_text, new_text in zip(old_texts, new_texts): + text_input = self.tokenizer( + [old_text, new_text], + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + old_emb, new_emb = text_embeddings + old_embs.append(old_emb) + new_embs.append(new_emb) + + # identify corresponding destinations for each token in old_emb + idxs_replaces = [] + for old_text, new_text in zip(old_texts, new_texts): + tokens_a = self.tokenizer(old_text).input_ids + tokens_b = self.tokenizer(new_text).input_ids + tokens_a = [self.tokenizer.encode("a ")[1] if self.tokenizer.decode(t) == "an" else t for t in tokens_a] + tokens_b = [self.tokenizer.encode("a ")[1] if self.tokenizer.decode(t) == "an" else t for t in tokens_b] + num_orig_tokens = len(tokens_a) + idxs_replace = [] + j = 0 + for i in range(num_orig_tokens): + curr_token = tokens_a[i] + while tokens_b[j] != curr_token: + j += 1 + idxs_replace.append(j) + j += 1 + while j < 77: + idxs_replace.append(j) + j += 1 + while len(idxs_replace) < 77: + idxs_replace.append(76) + idxs_replaces.append(idxs_replace) + + # prepare batch: for each pair of sentences, old context and new values + contexts, valuess = [], [] + for old_emb, new_emb, idxs_replace in zip(old_embs, new_embs, idxs_replaces): + context = old_emb.detach() + values = [] + with torch.no_grad(): + for layer in self.projection_matrices: + values.append(layer(new_emb[idxs_replace]).detach()) + contexts.append(context) + valuess.append(values) + + # edit the model + for layer_num in range(len(self.projection_matrices)): + # mat1 = \lambda W + \sum{v k^T} + mat1 = lamb * self.projection_matrices[layer_num].weight + + # mat2 = \lambda I + \sum{k k^T} + mat2 = lamb * torch.eye( + self.projection_matrices[layer_num].weight.shape[1], + device=self.projection_matrices[layer_num].weight.device, + ) + + # aggregate sums for mat1, mat2 + for context, values in zip(contexts, valuess): + context_vector = context.reshape(context.shape[0], context.shape[1], 1) + context_vector_T = context.reshape(context.shape[0], 1, context.shape[1]) + value_vector = values[layer_num].reshape(values[layer_num].shape[0], values[layer_num].shape[1], 1) + for_mat1 = (value_vector @ context_vector_T).sum(dim=0) + for_mat2 = (context_vector @ context_vector_T).sum(dim=0) + mat1 += for_mat1 + mat2 += for_mat2 + + # update projection matrix + self.projection_matrices[layer_num].weight = torch.nn.Parameter(mat1 @ torch.inverse(mat2)) + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + + Examples: + + ```py + >>> import torch + >>> from diffusers import StableDiffusionModelEditingPipeline + + >>> model_ckpt = "CompVis/stable-diffusion-v1-4" + >>> pipe = StableDiffusionModelEditingPipeline.from_pretrained(model_ckpt) + + >>> pipe = pipe.to("cuda") + + >>> source_prompt = "A pack of roses" + >>> destination_prompt = "A pack of blue roses" + >>> pipe.edit_model(source_prompt, destination_prompt) + + >>> prompt = "A field of roses" + >>> image = pipe(prompt).images[0] + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py b/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py new file mode 100644 index 0000000000000000000000000000000000000000..ea81be87a0f4b246a0b328e050da5a07357141e9 --- /dev/null +++ b/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py @@ -0,0 +1,798 @@ +# Copyright 2025 ParaDiGMS authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ....image_processor import VaeImageProcessor +from ....loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ....models import AutoencoderKL, UNet2DConditionModel +from ....models.lora import adjust_lora_scale_text_encoder +from ....schedulers import KarrasDiffusionSchedulers +from ....utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import DDPMParallelScheduler + >>> from diffusers import StableDiffusionParadigmsPipeline + + >>> scheduler = DDPMParallelScheduler.from_pretrained( + ... "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="scheduler" + ... ) + + >>> pipe = StableDiffusionParadigmsPipeline.from_pretrained( + ... "stable-diffusion-v1-5/stable-diffusion-v1-5", scheduler=scheduler, torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> ngpu, batch_per_device = torch.cuda.device_count(), 5 + >>> pipe.wrapped_unet = torch.nn.DataParallel(pipe.unet, device_ids=[d for d in range(ngpu)]) + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt, parallel=ngpu * batch_per_device, num_inference_steps=1000).images[0] + ``` +""" + + +class StableDiffusionParadigmsPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using a parallelized version of Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # attribute to wrap the unet with torch.nn.DataParallel when running multiple denoising steps on multiple GPUs + self.wrapped_unet = self.unet + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _cumsum(self, input, dim, debug=False): + if debug: + # cumsum_cuda_kernel does not have a deterministic implementation + # so perform cumsum on cpu for debugging purposes + return torch.cumsum(input.cpu().float(), dim=dim).to(input.device) + else: + return torch.cumsum(input, dim=dim) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + parallel: int = 10, + tolerance: float = 0.1, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + debug: bool = False, + clip_skip: int = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + parallel (`int`, *optional*, defaults to 10): + The batch size to use when doing parallel sampling. More parallelism may lead to faster inference but + requires higher memory usage and can also require more total FLOPs. + tolerance (`float`, *optional*, defaults to 0.1): + The error tolerance for determining when to slide the batch window forward for parallel sampling. Lower + tolerance usually leads to less or no degradation. Higher tolerance is faster but can risk degradation + of sample quality. The tolerance is specified as a ratio of the scheduler's noise magnitude. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + debug (`bool`, *optional*, defaults to `False`): + Whether or not to run in debug mode. In debug mode, `torch.cumsum` is evaluated using the CPU. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + extra_step_kwargs.pop("generator", None) + + # # 7. Denoising loop + scheduler = self.scheduler + parallel = min(parallel, len(scheduler.timesteps)) + + begin_idx = 0 + end_idx = parallel + latents_time_evolution_buffer = torch.stack([latents] * (len(scheduler.timesteps) + 1)) + + # We must make sure the noise of stochastic schedulers such as DDPM is sampled only once per timestep. + # Sampling inside the parallel denoising loop will mess this up, so we pre-sample the noise vectors outside the denoising loop. + noise_array = torch.zeros_like(latents_time_evolution_buffer) + for j in range(len(scheduler.timesteps)): + base_noise = randn_tensor( + shape=latents.shape, generator=generator, device=latents.device, dtype=prompt_embeds.dtype + ) + noise = (self.scheduler._get_variance(scheduler.timesteps[j]) ** 0.5) * base_noise + noise_array[j] = noise.clone() + + # We specify the error tolerance as a ratio of the scheduler's noise magnitude. We similarly compute the error tolerance + # outside of the denoising loop to avoid recomputing it at every step. + # We will be dividing the norm of the noise, so we store its inverse here to avoid a division at every step. + inverse_variance_norm = 1.0 / torch.tensor( + [scheduler._get_variance(scheduler.timesteps[j]) for j in range(len(scheduler.timesteps))] + [0] + ).to(noise_array.device) + latent_dim = noise_array[0, 0].numel() + inverse_variance_norm = inverse_variance_norm[:, None] / latent_dim + + scaled_tolerance = tolerance**2 + + with self.progress_bar(total=num_inference_steps) as progress_bar: + steps = 0 + while begin_idx < len(scheduler.timesteps): + # these have shape (parallel_dim, 2*batch_size, ...) + # parallel_len is at most parallel, but could be less if we are at the end of the timesteps + # we are processing batch window of timesteps spanning [begin_idx, end_idx) + parallel_len = end_idx - begin_idx + + block_prompt_embeds = torch.stack([prompt_embeds] * parallel_len) + block_latents = latents_time_evolution_buffer[begin_idx:end_idx] + block_t = scheduler.timesteps[begin_idx:end_idx, None].repeat(1, batch_size * num_images_per_prompt) + t_vec = block_t + if do_classifier_free_guidance: + t_vec = t_vec.repeat(1, 2) + + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([block_latents] * 2, dim=1) if do_classifier_free_guidance else block_latents + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t_vec) + + # if parallel_len is small, no need to use multiple GPUs + net = self.wrapped_unet if parallel_len > 3 else self.unet + # predict the noise residual, shape is now [parallel_len * 2 * batch_size * num_images_per_prompt, ...] + model_output = net( + latent_model_input.flatten(0, 1), + t_vec.flatten(0, 1), + encoder_hidden_states=block_prompt_embeds.flatten(0, 1), + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + per_latent_shape = model_output.shape[1:] + if do_classifier_free_guidance: + model_output = model_output.reshape( + parallel_len, 2, batch_size * num_images_per_prompt, *per_latent_shape + ) + noise_pred_uncond, noise_pred_text = model_output[:, 0], model_output[:, 1] + model_output = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + model_output = model_output.reshape( + parallel_len * batch_size * num_images_per_prompt, *per_latent_shape + ) + + block_latents_denoise = scheduler.batch_step_no_noise( + model_output=model_output, + timesteps=block_t.flatten(0, 1), + sample=block_latents.flatten(0, 1), + **extra_step_kwargs, + ).reshape(block_latents.shape) + + # back to shape (parallel_dim, batch_size, ...) + # now we want to add the pre-sampled noise + # parallel sampling algorithm requires computing the cumulative drift from the beginning + # of the window, so we need to compute cumulative sum of the deltas and the pre-sampled noises. + delta = block_latents_denoise - block_latents + cumulative_delta = self._cumsum(delta, dim=0, debug=debug) + cumulative_noise = self._cumsum(noise_array[begin_idx:end_idx], dim=0, debug=debug) + + # if we are using an ODE-like scheduler (like DDIM), we don't want to add noise + if scheduler._is_ode_scheduler: + cumulative_noise = 0 + + block_latents_new = ( + latents_time_evolution_buffer[begin_idx][None,] + cumulative_delta + cumulative_noise + ) + cur_error = torch.linalg.norm( + (block_latents_new - latents_time_evolution_buffer[begin_idx + 1 : end_idx + 1]).reshape( + parallel_len, batch_size * num_images_per_prompt, -1 + ), + dim=-1, + ).pow(2) + error_ratio = cur_error * inverse_variance_norm[begin_idx + 1 : end_idx + 1] + + # find the first index of the vector error_ratio that is greater than error tolerance + # we can shift the window for the next iteration up to this index + error_ratio = torch.nn.functional.pad( + error_ratio, (0, 0, 0, 1), value=1e9 + ) # handle the case when everything is below ratio, by padding the end of parallel_len dimension + any_error_at_time = torch.max(error_ratio > scaled_tolerance, dim=1).values.int() + ind = torch.argmax(any_error_at_time).item() + + # compute the new begin and end idxs for the window + new_begin_idx = begin_idx + min(1 + ind, parallel) + new_end_idx = min(new_begin_idx + parallel, len(scheduler.timesteps)) + + # store the computed latents for the current window in the global buffer + latents_time_evolution_buffer[begin_idx + 1 : end_idx + 1] = block_latents_new + # initialize the new sliding window latents with the end of the current window, + # should be better than random initialization + latents_time_evolution_buffer[end_idx : new_end_idx + 1] = latents_time_evolution_buffer[end_idx][ + None, + ] + + steps += 1 + + progress_bar.update(new_begin_idx - begin_idx) + if callback is not None and steps % callback_steps == 0: + callback(begin_idx, block_t[begin_idx], latents_time_evolution_buffer[begin_idx]) + + begin_idx = new_begin_idx + end_idx = new_end_idx + + latents = latents_time_evolution_buffer[-1] + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py b/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py new file mode 100644 index 0000000000000000000000000000000000000000..0955b6fe48a149af5f76ce717729fbb21a1238a1 --- /dev/null +++ b/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py @@ -0,0 +1,1311 @@ +# Copyright 2025 Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import ( + BlipForConditionalGeneration, + BlipProcessor, + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, +) + +from ....image_processor import PipelineImageInput, VaeImageProcessor +from ....loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ....models import AutoencoderKL, UNet2DConditionModel +from ....models.attention_processor import Attention +from ....models.lora import adjust_lora_scale_text_encoder +from ....schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler +from ....schedulers.scheduling_ddim_inverse import DDIMInverseScheduler +from ....utils import ( + PIL_INTERPOLATION, + USE_PEFT_BACKEND, + BaseOutput, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class Pix2PixInversionPipelineOutput(BaseOutput, TextualInversionLoaderMixin): + """ + Output class for Stable Diffusion pipelines. + + Args: + latents (`torch.Tensor`) + inverted latents tensor + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + latents: torch.Tensor + images: list[PIL.Image.Image] | np.ndarray + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import requests + >>> import torch + + >>> from diffusers import DDIMScheduler, StableDiffusionPix2PixZeroPipeline + + + >>> def download(embedding_url, local_filepath): + ... r = requests.get(embedding_url) + ... with open(local_filepath, "wb") as f: + ... f.write(r.content) + + + >>> model_ckpt = "CompVis/stable-diffusion-v1-4" + >>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16) + >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.to("cuda") + + >>> prompt = "a high resolution painting of a cat in the style of van gough" + >>> source_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/cat.pt" + >>> target_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/dog.pt" + + >>> for url in [source_emb_url, target_emb_url]: + ... download(url, url.split("/")[-1]) + + >>> src_embeds = torch.load(source_emb_url.split("/")[-1]) + >>> target_embeds = torch.load(target_emb_url.split("/")[-1]) + >>> images = pipeline( + ... prompt, + ... source_embeds=src_embeds, + ... target_embeds=target_embeds, + ... num_inference_steps=50, + ... cross_attention_guidance_amount=0.15, + ... ).images + + >>> images[0].save("edited_image_dog.png") + ``` +""" + +EXAMPLE_INVERT_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from transformers import BlipForConditionalGeneration, BlipProcessor + >>> from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionPix2PixZeroPipeline + + >>> import requests + >>> from PIL import Image + + >>> captioner_id = "Salesforce/blip-image-captioning-base" + >>> processor = BlipProcessor.from_pretrained(captioner_id) + >>> model = BlipForConditionalGeneration.from_pretrained( + ... captioner_id, torch_dtype=torch.float16, low_cpu_mem_usage=True + ... ) + + >>> sd_model_ckpt = "CompVis/stable-diffusion-v1-4" + >>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained( + ... sd_model_ckpt, + ... caption_generator=model, + ... caption_processor=processor, + ... torch_dtype=torch.float16, + ... safety_checker=None, + ... ) + + >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.enable_model_cpu_offload() + + >>> img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png" + + >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512)) + >>> # generate caption + >>> caption = pipeline.generate_caption(raw_image) + + >>> # "a photography of a cat with flowers and dai dai daie - daie - daie kasaii" + >>> inv_latents = pipeline.invert(caption, image=raw_image).latents + >>> # we need to generate source and target embeds + + >>> source_prompts = ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"] + + >>> target_prompts = ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"] + + >>> source_embeds = pipeline.get_embeds(source_prompts) + >>> target_embeds = pipeline.get_embeds(target_prompts) + >>> # the latents can then be used to edit a real image + >>> # when using Stable Diffusion 2 or other models that use v-prediction + >>> # set `cross_attention_guidance_amount` to 0.01 or less to avoid input latent gradient explosion + + >>> image = pipeline( + ... caption, + ... source_embeds=source_embeds, + ... target_embeds=target_embeds, + ... num_inference_steps=50, + ... cross_attention_guidance_amount=0.15, + ... generator=generator, + ... latents=inv_latents, + ... negative_prompt=caption, + ... ).images[0] + >>> image.save("edited_image.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +def prepare_unet(unet: UNet2DConditionModel): + """Modifies the UNet (`unet`) to perform Pix2Pix Zero optimizations.""" + pix2pix_zero_attn_procs = {} + for name in unet.attn_processors.keys(): + module_name = name.replace(".processor", "") + module = unet.get_submodule(module_name) + if "attn2" in name: + pix2pix_zero_attn_procs[name] = Pix2PixZeroAttnProcessor(is_pix2pix_zero=True) + module.requires_grad_(True) + else: + pix2pix_zero_attn_procs[name] = Pix2PixZeroAttnProcessor(is_pix2pix_zero=False) + module.requires_grad_(False) + + unet.set_attn_processor(pix2pix_zero_attn_procs) + return unet + + +class Pix2PixZeroL2Loss: + def __init__(self): + self.loss = 0.0 + + def compute_loss(self, predictions, targets): + self.loss += ((predictions - targets) ** 2).sum((1, 2)).mean(0) + + +class Pix2PixZeroAttnProcessor: + """An attention processor class to store the attention weights. + In Pix2Pix Zero, it happens during computations in the cross-attention blocks.""" + + def __init__(self, is_pix2pix_zero=False): + self.is_pix2pix_zero = is_pix2pix_zero + if self.is_pix2pix_zero: + self.reference_cross_attn_map = {} + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + timestep=None, + loss=None, + ): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + if self.is_pix2pix_zero and timestep is not None: + # new bookkeeping to save the attention weights. + if loss is None: + self.reference_cross_attn_map[timestep.item()] = attention_probs.detach().cpu() + # compute loss + elif loss is not None: + prev_attn_probs = self.reference_cross_attn_map.pop(timestep.item()) + loss.compute_loss(attention_probs, prev_attn_probs.to(attention_probs.device)) + + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin): + r""" + Pipeline for pixel-level image editing using Pix2Pix Zero. Based on Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], or [`DDPMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + requires_safety_checker (bool): + Whether the pipeline requires a safety checker. We recommend setting it to True if you're using the + pipeline publicly. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = [ + "safety_checker", + "feature_extractor", + "caption_generator", + "caption_processor", + "inverse_scheduler", + ] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler | DDIMScheduler | EulerAncestralDiscreteScheduler | LMSDiscreteScheduler, + feature_extractor: CLIPImageProcessor, + safety_checker: StableDiffusionSafetyChecker, + inverse_scheduler: DDIMInverseScheduler, + caption_generator: BlipForConditionalGeneration, + caption_processor: BlipProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + caption_processor=caption_processor, + caption_generator=caption_generator, + inverse_scheduler=inverse_scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + source_embeds, + target_embeds, + callback_steps, + prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if source_embeds is None and target_embeds is None: + raise ValueError("`source_embeds` and `target_embeds` cannot be undefined.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def generate_caption(self, images): + """Generates caption for a given image.""" + text = "a photography of" + + prev_device = self.caption_generator.device + + device = self._execution_device + inputs = self.caption_processor(images, text, return_tensors="pt").to( + device=device, dtype=self.caption_generator.dtype + ) + self.caption_generator.to(device) + outputs = self.caption_generator.generate(**inputs, max_new_tokens=128) + + # offload caption generator + self.caption_generator.to(prev_device) + + caption = self.caption_processor.batch_decode(outputs, skip_special_tokens=True)[0] + return caption + + def construct_direction(self, embs_source: torch.Tensor, embs_target: torch.Tensor): + """Constructs the edit direction to steer the image generation process semantically.""" + return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0) + + @torch.no_grad() + def get_embeds(self, prompt: list[str], batch_size: int = 16) -> torch.Tensor: + num_prompts = len(prompt) + embeds = [] + for i in range(0, num_prompts, batch_size): + prompt_slice = prompt[i : i + batch_size] + + input_ids = self.tokenizer( + prompt_slice, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ).input_ids + + input_ids = input_ids.to(self.text_encoder.device) + embeds.append(self.text_encoder(input_ids)[0]) + + return torch.cat(embeds, dim=0).mean(0)[None] + + def prepare_image_latents(self, image, batch_size, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 4: + latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0) + else: + latents = self.vae.encode(image).latent_dist.sample(generator) + + latents = self.vae.config.scaling_factor * latents + + if batch_size != latents.shape[0]: + if batch_size % latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_latents_per_image = batch_size // latents.shape[0] + latents = torch.cat([latents] * additional_latents_per_image, dim=0) + else: + raise ValueError( + f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts." + ) + else: + latents = torch.cat([latents], dim=0) + + return latents + + def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int): + pred_type = self.inverse_scheduler.config.prediction_type + alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep] + + beta_prod_t = 1 - alpha_prod_t + + if pred_type == "epsilon": + return model_output + elif pred_type == "sample": + return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5) + elif pred_type == "v_prediction": + return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`" + ) + + def auto_corr_loss(self, hidden_states, generator=None): + reg_loss = 0.0 + for i in range(hidden_states.shape[0]): + for j in range(hidden_states.shape[1]): + noise = hidden_states[i : i + 1, j : j + 1, :, :] + while True: + roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item() + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2 + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2 + + if noise.shape[2] <= 8: + break + noise = F.avg_pool2d(noise, kernel_size=2) + return reg_loss + + def kl_divergence(self, hidden_states): + mean = hidden_states.mean() + var = hidden_states.var() + return var + mean**2 - 1 - torch.log(var + 1e-7) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + source_embeds: torch.Tensor = None, + target_embeds: torch.Tensor = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + cross_attention_guidance_amount: float = 0.1, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int | None = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + source_embeds (`torch.Tensor`): + Source concept embeddings. Generation of the embeddings as per the [original + paper](https://huggingface.co/papers/2302.03027). Used in discovering the edit direction. + target_embeds (`torch.Tensor`): + Target concept embeddings. Generation of the embeddings as per the [original + paper](https://huggingface.co/papers/2302.03027). Used in discovering the edit direction. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + cross_attention_guidance_amount (`float`, defaults to 0.1): + Amount of guidance needed from the reference cross-attention maps. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Define the spatial resolutions. + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + source_embeds, + target_embeds, + callback_steps, + prompt_embeds, + ) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Generate the inverted noise from the input image or any other image + # generated from the input prompt. + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + latents_init = latents.clone() + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Rejig the UNet so that we can obtain the cross-attenion maps and + # use them for guiding the subsequent image generation. + self.unet = prepare_unet(self.unet) + + # 7. Denoising loop where we obtain the cross-attention maps. + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs={"timestep": t}, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # 8. Compute the edit directions. + edit_direction = self.construct_direction(source_embeds, target_embeds).to(prompt_embeds.device) + + # 9. Edit the prompt embeddings as per the edit directions discovered. + prompt_embeds_edit = prompt_embeds.clone() + prompt_embeds_edit[1:2] += edit_direction + + # 10. Second denoising loop to generate the edited image. + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + latents = latents_init + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # we want to learn the latent such that it steers the generation + # process towards the edited direction, so make the make initial + # noise learnable + x_in = latent_model_input.detach().clone() + x_in.requires_grad = True + + # optimizer + opt = torch.optim.SGD([x_in], lr=cross_attention_guidance_amount) + + with torch.enable_grad(): + # initialize loss + loss = Pix2PixZeroL2Loss() + + # predict the noise residual + noise_pred = self.unet( + x_in, + t, + encoder_hidden_states=prompt_embeds_edit.detach(), + cross_attention_kwargs={"timestep": t, "loss": loss}, + ).sample + + loss.loss.backward(retain_graph=False) + opt.step() + + # recompute the noise + noise_pred = self.unet( + x_in.detach(), + t, + encoder_hidden_states=prompt_embeds_edit, + cross_attention_kwargs={"timestep": None}, + ).sample + + latents = x_in.detach().chunk(2)[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_INVERT_DOC_STRING) + def invert( + self, + prompt: str | None = None, + image: PipelineImageInput = None, + num_inference_steps: int = 50, + guidance_scale: float = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + cross_attention_guidance_amount: float = 0.1, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int | None = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + lambda_auto_corr: float = 20.0, + lambda_kl: float = 20.0, + num_reg_steps: int = 5, + num_auto_corr_rolls: int = 5, + ): + r""" + Function used to generate inverted latents given a prompt and image. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.Tensor` `np.ndarray`, `PIL.Image.Image`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, or tensor representing an image batch which will be used for conditioning. Can also accept + image latents as `image`, if passing latents directly, it will not be encoded again. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 1): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + cross_attention_guidance_amount (`float`, defaults to 0.1): + Amount of guidance needed from the reference cross-attention maps. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + lambda_auto_corr (`float`, *optional*, defaults to 20.0): + Lambda parameter to control auto correction + lambda_kl (`float`, *optional*, defaults to 20.0): + Lambda parameter to control Kullback–Leibler divergence output + num_reg_steps (`int`, *optional*, defaults to 5): + Number of regularization loss steps + num_auto_corr_rolls (`int`, *optional*, defaults to 5): + Number of auto correction roll steps + + Examples: + + Returns: + [`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`] or + `tuple`: + [`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is the inverted + latents tensor and then second is the corresponding decoded image. + """ + # 1. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Preprocess image + image = self.image_processor.preprocess(image) + + # 4. Prepare latent variables + latents = self.prepare_image_latents(image, batch_size, self.vae.dtype, device, generator) + + # 5. Encode input prompt + num_images_per_prompt = 1 + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + self.inverse_scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.inverse_scheduler.timesteps + + # 6. Rejig the UNet so that we can obtain the cross-attenion maps and + # use them for guiding the subsequent image generation. + self.unet = prepare_unet(self.unet) + + # 7. Denoising loop where we obtain the cross-attention maps. + num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs={"timestep": t}, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # regularization of the noise prediction + with torch.enable_grad(): + for _ in range(num_reg_steps): + if lambda_auto_corr > 0: + for _ in range(num_auto_corr_rolls): + var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) + + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + + l_ac = self.auto_corr_loss(var_epsilon, generator=generator) + l_ac.backward() + + grad = var.grad.detach() / num_auto_corr_rolls + noise_pred = noise_pred - lambda_auto_corr * grad + + if lambda_kl > 0: + var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) + + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + + l_kld = self.kl_divergence(var_epsilon) + l_kld.backward() + + grad = var.grad.detach() + noise_pred = noise_pred - lambda_kl * grad + + noise_pred = noise_pred.detach() + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + inverted_latents = latents.detach().clone() + + # 8. Post-processing + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (inverted_latents, image) + + return Pix2PixInversionPipelineOutput(latents=inverted_latents, images=image) diff --git a/diffusers/pipelines/deprecated/stochastic_karras_ve/__init__.py b/diffusers/pipelines/deprecated/stochastic_karras_ve/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15c9a8c27f98dd7e1913bd57dfd5e8dae71172b4 --- /dev/null +++ b/diffusers/pipelines/deprecated/stochastic_karras_ve/__init__.py @@ -0,0 +1,19 @@ +from typing import TYPE_CHECKING + +from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule + + +_import_structure = {"pipeline_stochastic_karras_ve": ["KarrasVePipeline"]} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_stochastic_karras_ve import KarrasVePipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py b/diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py new file mode 100644 index 0000000000000000000000000000000000000000..ce2c785c8d98fc6bbdfcc25d2a401ea1b725a5eb --- /dev/null +++ b/diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py @@ -0,0 +1,126 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from ....models import UNet2DModel +from ....schedulers import KarrasVeScheduler +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class KarrasVePipeline(DiffusionPipeline): + r""" + Pipeline for unconditional image generation. + + Parameters: + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image. + scheduler ([`KarrasVeScheduler`]): + A scheduler to be used in combination with `unet` to denoise the encoded image. + """ + + # add type hints for linting + unet: UNet2DModel + scheduler: KarrasVeScheduler + + def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 50, + generator: torch.Generator | list[torch.Generator] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + **kwargs, + ) -> tuple | ImagePipelineOutput: + r""" + The call function to the pipeline for generation. + + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. + + Example: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + img_size = self.unet.config.sample_size + shape = (batch_size, 3, img_size, img_size) + + model = self.unet + + # sample x_0 ~ N(0, sigma_0^2 * I) + sample = randn_tensor(shape, generator=generator, device=self.device) * self.scheduler.init_noise_sigma + + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.progress_bar(self.scheduler.timesteps): + # here sigma_t == t_i from the paper + sigma = self.scheduler.schedule[t] + sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0 + + # 1. Select temporarily increased noise level sigma_hat + # 2. Add new noise to move from sample_i to sample_hat + sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator) + + # 3. Predict the noise residual given the noise magnitude `sigma_hat` + # The model inputs and output are adjusted by following eq. (213) in [1]. + model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2).sample + + # 4. Evaluate dx/dt at sigma_hat + # 5. Take Euler step from sigma to sigma_prev + step_output = self.scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat) + + if sigma_prev != 0: + # 6. Apply 2nd order correction + # The model inputs and output are adjusted by following eq. (213) in [1]. + model_output = (sigma_prev / 2) * model((step_output.prev_sample + 1) / 2, sigma_prev / 2).sample + step_output = self.scheduler.step_correct( + model_output, + sigma_hat, + sigma_prev, + sample_hat, + step_output.prev_sample, + step_output["derivative"], + ) + sample = step_output.prev_sample + + sample = (sample / 2 + 0.5).clamp(0, 1) + image = sample.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/deprecated/versatile_diffusion/__init__.py b/diffusers/pipelines/deprecated/versatile_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8ea6ef6e2f65b96aebebdf72cb80135003e4f08d --- /dev/null +++ b/diffusers/pipelines/deprecated/versatile_diffusion/__init__.py @@ -0,0 +1,71 @@ +from typing import TYPE_CHECKING + +from ....utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ....utils.dummy_torch_and_transformers_objects import ( + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, + ) + + _dummy_objects.update( + { + "VersatileDiffusionDualGuidedPipeline": VersatileDiffusionDualGuidedPipeline, + "VersatileDiffusionImageVariationPipeline": VersatileDiffusionImageVariationPipeline, + "VersatileDiffusionPipeline": VersatileDiffusionPipeline, + "VersatileDiffusionTextToImagePipeline": VersatileDiffusionTextToImagePipeline, + } + ) +else: + _import_structure["modeling_text_unet"] = ["UNetFlatConditionModel"] + _import_structure["pipeline_versatile_diffusion"] = ["VersatileDiffusionPipeline"] + _import_structure["pipeline_versatile_diffusion_dual_guided"] = ["VersatileDiffusionDualGuidedPipeline"] + _import_structure["pipeline_versatile_diffusion_image_variation"] = ["VersatileDiffusionImageVariationPipeline"] + _import_structure["pipeline_versatile_diffusion_text_to_image"] = ["VersatileDiffusionTextToImagePipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ....utils.dummy_torch_and_transformers_objects import ( + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, + ) + else: + from .pipeline_versatile_diffusion import VersatileDiffusionPipeline + from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline + from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline + from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..7be159d77af5ddf3392fbae4c1b808d953bdb801 --- /dev/null +++ b/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -0,0 +1,2438 @@ +from typing import Any + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.utils import deprecate + +from ....configuration_utils import ConfigMixin, register_to_config +from ....models import ModelMixin +from ....models.activations import get_activation +from ....models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + AttentionProcessor, + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + AttnProcessor, +) +from ....models.embeddings import ( + GaussianFourierProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from ....models.resnet import ResnetBlockCondNorm2D +from ....models.transformers.dual_transformer_2d import DualTransformer2DModel +from ....models.transformers.transformer_2d import Transformer2DModel +from ....models.unets.unet_2d_condition import UNet2DConditionOutput +from ....utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ....utils.torch_utils import apply_freeu + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + num_attention_heads, + transformer_layers_per_block, + attention_type, + attention_head_dim, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, + dropout=0.0, +): + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlockFlat": + return DownBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlockFlat": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockFlat") + return CrossAttnDownBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{down_block_type} is not supported.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + num_attention_heads, + transformer_layers_per_block, + resolution_idx, + attention_type, + attention_head_dim, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, + dropout=0.0, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlockFlat": + return UpBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlockFlat": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockFlat") + return CrossAttnUpBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{up_block_type} is not supported.") + + +class FourierEmbedder(nn.Module): + def __init__(self, num_freqs=64, temperature=100): + super().__init__() + + self.num_freqs = num_freqs + self.temperature = temperature + + freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs) + freq_bands = freq_bands[None, None, None] + self.register_buffer("freq_bands", freq_bands, persistent=False) + + def __call__(self, x): + x = self.freq_bands * x.unsqueeze(-1) + return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1) + + +class GLIGENTextBoundingboxProjection(nn.Module): + def __init__(self, positive_len, out_dim, feature_type, fourier_freqs=8): + super().__init__() + self.positive_len = positive_len + self.out_dim = out_dim + + self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs) + self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy + + if isinstance(out_dim, tuple): + out_dim = out_dim[0] + + if feature_type == "text-only": + self.linears = nn.Sequential( + nn.Linear(self.positive_len + self.position_dim, 512), + nn.SiLU(), + nn.Linear(512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) + + elif feature_type == "text-image": + self.linears_text = nn.Sequential( + nn.Linear(self.positive_len + self.position_dim, 512), + nn.SiLU(), + nn.Linear(512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + self.linears_image = nn.Sequential( + nn.Linear(self.positive_len + self.position_dim, 512), + nn.SiLU(), + nn.Linear(512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) + self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) + + self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim])) + + def forward( + self, + boxes, + masks, + positive_embeddings=None, + phrases_masks=None, + image_masks=None, + phrases_embeddings=None, + image_embeddings=None, + ): + masks = masks.unsqueeze(-1) + + xyxy_embedding = self.fourier_embedder(boxes) + xyxy_null = self.null_position_feature.view(1, 1, -1) + xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null + + if positive_embeddings: + positive_null = self.null_positive_feature.view(1, 1, -1) + positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null + + objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1)) + else: + phrases_masks = phrases_masks.unsqueeze(-1) + image_masks = image_masks.unsqueeze(-1) + + text_null = self.null_text_feature.view(1, 1, -1) + image_null = self.null_image_feature.view(1, 1, -1) + + phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null + image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null + + objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1)) + objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1)) + objs = torch.cat([objs_text, objs_image], dim=1) + + return objs + + +class UNetFlatConditionModel(ModelMixin, ConfigMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`): + Block type for middle of UNet, it can be one of `UNetMidBlockFlatCrossAttn`, `UNetMidBlockFlat`, or + `UNetMidBlockFlatSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int`, `tuple[int]`, or `tuple[tuple]` , *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`], + [`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`]. + reverse_transformer_layers_per_block : (`tuple[tuple]`, *optional*, defaults to None): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling + blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `tuple[tuple]` and for + [`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`], + [`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlockFlat`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, + *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, + *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlockFlatSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "ResnetBlockFlat", "CrossAttnUpBlockFlat"] + + @register_to_config + def __init__( + self, + sample_size: int | None = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: tuple[str] = ( + "CrossAttnDownBlockFlat", + "CrossAttnDownBlockFlat", + "CrossAttnDownBlockFlat", + "DownBlockFlat", + ), + mid_block_type: str = "UNetMidBlockFlatCrossAttn", + up_block_types: tuple[str] = ( + "UpBlockFlat", + "CrossAttnUpBlockFlat", + "CrossAttnUpBlockFlat", + "CrossAttnUpBlockFlat", + ), + only_cross_attention: bool | tuple[bool] = False, + block_out_channels: tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int | tuple[int] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + dropout: float = 0.0, + act_fn: str = "silu", + norm_num_groups: int | None = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int | tuple[int] = 1280, + transformer_layers_per_block: int | tuple[int] | tuple[tuple] = 1, + reverse_transformer_layers_per_block: tuple[tuple[int]] | None = None, + encoder_hid_dim: int | None = None, + encoder_hid_dim_type: str | None = None, + attention_head_dim: int | tuple[int] = 8, + num_attention_heads: int | tuple[int] | None = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: str | None = None, + addition_embed_type: str | None = None, + addition_time_embed_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: int | None = None, + time_embedding_act_fn: str | None = None, + timestep_post_act: str | None = None, + time_cond_proj_dim: int | None = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: int | None = None, + attention_type: str = "default", + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: bool | None = None, + cross_attention_norm: str | None = None, + addition_embed_type_num_heads=64, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = LinearMultiDim( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlockFlatCrossAttn": + self.mid_block = UNetMidBlockFlatCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + elif mid_block_type == "UNetMidBlockFlatSimpleCrossAttn": + self.mid_block = UNetMidBlockFlatSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + attention_head_dim=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type == "UNetMidBlockFlat": + self.mid_block = UNetMidBlockFlat( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + num_layers=0, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + add_attention=False, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = ( + list(reversed(transformer_layers_per_block)) + if reverse_transformer_layers_per_block is None + else reverse_transformer_layers_per_block + ) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resolution_idx=i, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = LinearMultiDim( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + if attention_type in ["gated", "gated-text-image"]: + positive_len = 768 + if isinstance(cross_attention_dim, int): + positive_len = cross_attention_dim + elif isinstance(cross_attention_dim, (list, tuple)): + positive_len = cross_attention_dim[0] + + feature_type = "text-only" if attention_type == "gated" else "text-image" + self.position_net = GLIGENTextBoundingboxProjection( + positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type + ) + + @property + def attn_processors(self) -> dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: list[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def enable_freeu(self, s1, s2, b1, b2): + r"""Enables the FreeU mechanism from https://huggingface.co/papers/2309.11497. + + The suffixes after the scaling factors represent the stage blocks where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + for i, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + def disable_freeu(self): + """Disables the FreeU mechanism.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for i, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: + setattr(upsample_block, k, None) + + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + > [!WARNING] > This API is 🧪 experimental. + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + > [!WARNING] > This API is 🧪 experimental. + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def unload_lora(self): + """Unloads LoRA weights.""" + deprecate( + "unload_lora", + "0.28.0", + "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().", + ) + for module in self.modules(): + if hasattr(module, "set_lora_layer"): + module.set_lora_layer(None) + + def forward( + self, + sample: torch.Tensor, + timestep: torch.Tensor | float | int, + encoder_hidden_states: torch.Tensor, + class_labels: torch.Tensor | None = None, + timestep_cond: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + added_cond_kwargs: dict[str, torch.Tensor] | None = None, + down_block_additional_residuals: tuple[torch.Tensor] | None = None, + mid_block_additional_residual: torch.Tensor | None = None, + down_intrablock_additional_residuals: tuple[torch.Tensor] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + return_dict: bool = True, + ) -> UNet2DConditionOutput | tuple: + r""" + The [`UNetFlatConditionModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added to UNet long skip connections from down blocks to up blocks for + example from ControlNet side model(s) + mid_block_additional_residual (`torch.Tensor`, *optional*): + additional residual to be added to UNet mid block output, for example from ControlNet side model + down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) + + Returns: + [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, + otherwise a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + # Forward upsample size to force interpolation output size. + forward_upsample_size = True + break + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" + if isinstance(timestep, float): + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + else: + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds) + encoder_hidden_states = (encoder_hidden_states, image_embeds) + + # 2. pre-process + sample = self.conv_in(sample) + + # 2.5 GLIGEN position net + if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + + # 3. down + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlockFlat + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape + ): + sample += down_intrablock_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + scale=lora_scale, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) + + +class LinearMultiDim(nn.Linear): + def __init__(self, in_features, out_features=None, second_dim=4, *args, **kwargs): + in_features = [in_features, second_dim, 1] if isinstance(in_features, int) else list(in_features) + if out_features is None: + out_features = in_features + out_features = [out_features, second_dim, 1] if isinstance(out_features, int) else list(out_features) + self.in_features_multidim = in_features + self.out_features_multidim = out_features + super().__init__(np.array(in_features).prod(), np.array(out_features).prod()) + + def forward(self, input_tensor, *args, **kwargs): + shape = input_tensor.shape + n_dim = len(self.in_features_multidim) + input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_features) + output_tensor = super().forward(input_tensor) + output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_features_multidim) + return output_tensor + + +class ResnetBlockFlat(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + time_embedding_norm="default", + use_in_shortcut=None, + second_dim=4, + **kwargs, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + + in_channels = [in_channels, second_dim, 1] if isinstance(in_channels, int) else list(in_channels) + self.in_channels_prod = np.array(in_channels).prod() + self.channels_multidim = in_channels + + if out_channels is not None: + out_channels = [out_channels, second_dim, 1] if isinstance(out_channels, int) else list(out_channels) + out_channels_prod = np.array(out_channels).prod() + self.out_channels_multidim = out_channels + else: + out_channels_prod = self.in_channels_prod + self.out_channels_multidim = self.channels_multidim + self.time_embedding_norm = time_embedding_norm + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=self.in_channels_prod, eps=eps, affine=True) + self.conv1 = torch.nn.Conv2d(self.in_channels_prod, out_channels_prod, kernel_size=1, padding=0) + + if temb_channels is not None: + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels_prod) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels_prod, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels_prod, out_channels_prod, kernel_size=1, padding=0) + + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = ( + self.in_channels_prod != out_channels_prod if use_in_shortcut is None else use_in_shortcut + ) + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + self.in_channels_prod, out_channels_prod, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, input_tensor, temb): + shape = input_tensor.shape + n_dim = len(self.channels_multidim) + input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_channels_prod, 1, 1) + input_tensor = input_tensor.view(-1, self.in_channels_prod, 1, 1) + + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + output_tensor = output_tensor.view(*shape[0:-n_dim], -1) + output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_channels_multidim) + + return output_tensor + + +class DownBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + LinearMultiDim( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor | None = None + ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + output_states = () + + for resnet in self.resnets: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int | tuple[int] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + LinearMultiDim( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + additional_residuals: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + output_states = () + + blocks = list(zip(self.resnets, self.attentions)) + + for i, (resnet, attn) in enumerate(blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +# Copied from diffusers.models.unets.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim +class UpBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: int | None = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlockFlat( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + upsample_size: int | None = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +# Copied from diffusers.models.unets.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim +class CrossAttnUpBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: int | None = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int | tuple[int] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlockFlat( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + upsample_size: int | None = None, + attention_mask: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat +class UNetMidBlockFlat(nn.Module): + """ + A 2D UNet mid-block [`UNetMidBlockFlat`] with multiple residual blocks and optional attention blocks. + + Args: + in_channels (`int`): The number of input channels. + temb_channels (`int`): The number of temporal embedding channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_time_scale_shift (`str`, *optional*, defaults to `default`): + The type of normalization to apply to the time embeddings. This can help to improve the performance of the + model on tasks with long-range temporal dependencies. + resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + attn_groups (`int | None`, *optional*, defaults to None): The number of groups for the attention blocks. + resnet_pre_norm (`bool`, *optional*, defaults to `True`): + Whether to use pre-normalization for the resnet blocks. + add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks. + attention_head_dim (`int`, *optional*, defaults to 1): + Dimension of a single attention head. The number of attention heads is determined based on this value and + the number of input channels. + output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor. + + Returns: + `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels, + height, width)`. + + """ + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + attn_groups: int | None = None, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + if attn_groups is None: + attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None + + # there is always at least one resnet + if resnet_time_scale_shift == "spatial": + resnets = [ + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ] + else: + resnets = [ + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}." + ) + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=attn_groups, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor | None = None) -> torch.Tensor: + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if torch.is_grad_enabled() and self.gradient_checkpointing: + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat +class UNetMidBlockFlatCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + out_channels: int | None = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int | tuple[int] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_groups_out: int | None = None, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + + out_channels = out_channels or in_channels + self.in_channels = in_channels + self.out_channels = out_channels + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + resnet_groups_out = resnet_groups_out or resnet_groups + + # there is always at least one resnet + resnets = [ + ResnetBlockFlat( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + groups_out=resnet_groups_out, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for i in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups_out, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlockFlat( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups_out, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat +class UNetMidBlockFlatSimpleCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: str | None = None, + ): + super().__init__() + + self.has_cross_attention = True + + self.attention_head_dim = attention_head_dim + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + self.num_heads = in_channels // self.attention_head_dim + + # there is always at least one resnet + resnets = [ + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ] + attentions = [] + + for _ in range(num_layers): + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=in_channels, + cross_attention_dim=in_channels, + heads=self.num_heads, + dim_head=self.attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + # attn + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + # resnet + hidden_states = resnet(hidden_states, temb) + + return hidden_states diff --git a/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py b/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..101a1b72e7f95e4d28b68964f89d3e7566ccfe57 --- /dev/null +++ b/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py @@ -0,0 +1,421 @@ +import inspect +from typing import Callable + +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel + +from ....models import AutoencoderKL, UNet2DConditionModel +from ....schedulers import KarrasDiffusionSchedulers +from ....utils import logging +from ...pipeline_utils import DiffusionPipeline +from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline +from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline +from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + tokenizer: CLIPTokenizer + image_feature_extractor: CLIPImageProcessor + text_encoder: CLIPTextModel + image_encoder: CLIPVisionModel + image_unet: UNet2DConditionModel + text_unet: UNet2DConditionModel + vae: AutoencoderKL + scheduler: KarrasDiffusionSchedulers + + def __init__( + self, + tokenizer: CLIPTokenizer, + image_feature_extractor: CLIPImageProcessor, + text_encoder: CLIPTextModel, + image_encoder: CLIPVisionModel, + image_unet: UNet2DConditionModel, + text_unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + image_feature_extractor=image_feature_extractor, + text_encoder=text_encoder, + image_encoder=image_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + + @torch.no_grad() + def image_variation( + self, + image: torch.Tensor | PIL.Image.Image, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image`, `list[PIL.Image.Image]` or `torch.Tensor`): + The image prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + + >>> pipe = VersatileDiffusionPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe.image_variation(image, generator=generator).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + expected_components = inspect.signature(VersatileDiffusionImageVariationPipeline.__init__).parameters.keys() + components = {name: component for name, component in self.components.items() if name in expected_components} + return VersatileDiffusionImageVariationPipeline(**components)( + image=image, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + ) + + @torch.no_grad() + def text_to_image( + self, + prompt: str | list[str], + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide image generation. + height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionPipeline + >>> import torch + + >>> pipe = VersatileDiffusionPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe.text_to_image("an astronaut riding on a horse on mars", generator=generator).images[0] + >>> image.save("./astronaut.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + expected_components = inspect.signature(VersatileDiffusionTextToImagePipeline.__init__).parameters.keys() + components = {name: component for name, component in self.components.items() if name in expected_components} + temp_pipeline = VersatileDiffusionTextToImagePipeline(**components) + output = temp_pipeline( + prompt=prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + ) + # swap the attention blocks back to the original state + temp_pipeline._swap_unet_attention_blocks() + + return output + + @torch.no_grad() + def dual_guided( + self, + prompt: PIL.Image.Image | list[PIL.Image.Image], + image: str | list[str], + text_to_image_strength: float = 0.5, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide image generation. + height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + >>> text = "a red car in the sun" + + >>> pipe = VersatileDiffusionPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> text_to_image_strength = 0.75 + + >>> image = pipe.dual_guided( + ... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator + ... ).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + expected_components = inspect.signature(VersatileDiffusionDualGuidedPipeline.__init__).parameters.keys() + components = {name: component for name, component in self.components.items() if name in expected_components} + temp_pipeline = VersatileDiffusionDualGuidedPipeline(**components) + output = temp_pipeline( + prompt=prompt, + image=image, + text_to_image_strength=text_to_image_strength, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + ) + temp_pipeline._revert_dual_attention() + + return output diff --git a/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py new file mode 100644 index 0000000000000000000000000000000000000000..8f8fb712e023b34b84e380f73fd1b98c5573e8ce --- /dev/null +++ b/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -0,0 +1,561 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import PIL.Image +import torch +import torch.utils.checkpoint +from transformers import ( + CLIPImageProcessor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ....image_processor import VaeImageProcessor +from ....models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel +from ....schedulers import KarrasDiffusionSchedulers +from ....utils import deprecate, logging +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .modeling_text_unet import UNetFlatConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): + r""" + Pipeline for image-text dual-guided generation using Versatile Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [`~transformers.BERT`]. + tokenizer ([`~transformers.BertTokenizer`]): + A `BertTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "bert->unet->vqvae" + + tokenizer: CLIPTokenizer + image_feature_extractor: CLIPImageProcessor + text_encoder: CLIPTextModelWithProjection + image_encoder: CLIPVisionModelWithProjection + image_unet: UNet2DConditionModel + text_unet: UNetFlatConditionModel + vae: AutoencoderKL + scheduler: KarrasDiffusionSchedulers + + _optional_components = ["text_unet"] + + def __init__( + self, + tokenizer: CLIPTokenizer, + image_feature_extractor: CLIPImageProcessor, + text_encoder: CLIPTextModelWithProjection, + image_encoder: CLIPVisionModelWithProjection, + image_unet: UNet2DConditionModel, + text_unet: UNetFlatConditionModel, + vae: AutoencoderKL, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + self.register_modules( + tokenizer=tokenizer, + image_feature_extractor=image_feature_extractor, + text_encoder=text_encoder, + image_encoder=image_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + if self.text_unet is not None and ( + "dual_cross_attention" not in self.image_unet.config or not self.image_unet.config.dual_cross_attention + ): + # if loading from a universal checkpoint rather than a saved dual-guided pipeline + self._convert_to_dual_attention() + + def remove_unused_weights(self): + self.register_modules(text_unet=None) + + def _convert_to_dual_attention(self): + """ + Replace image_unet's `Transformer2DModel` blocks with `DualTransformer2DModel` that contains transformer blocks + from both `image_unet` and `text_unet` + """ + for name, module in self.image_unet.named_modules(): + if isinstance(module, Transformer2DModel): + parent_name, index = name.rsplit(".", 1) + index = int(index) + + image_transformer = self.image_unet.get_submodule(parent_name)[index] + text_transformer = self.text_unet.get_submodule(parent_name)[index] + + config = image_transformer.config + dual_transformer = DualTransformer2DModel( + num_attention_heads=config.num_attention_heads, + attention_head_dim=config.attention_head_dim, + in_channels=config.in_channels, + num_layers=config.num_layers, + dropout=config.dropout, + norm_num_groups=config.norm_num_groups, + cross_attention_dim=config.cross_attention_dim, + attention_bias=config.attention_bias, + sample_size=config.sample_size, + num_vector_embeds=config.num_vector_embeds, + activation_fn=config.activation_fn, + num_embeds_ada_norm=config.num_embeds_ada_norm, + ) + dual_transformer.transformers[0] = image_transformer + dual_transformer.transformers[1] = text_transformer + + self.image_unet.get_submodule(parent_name)[index] = dual_transformer + self.image_unet.register_to_config(dual_cross_attention=True) + + def _revert_dual_attention(self): + """ + Revert the image_unet `DualTransformer2DModel` blocks back to `Transformer2DModel` with image_unet weights Call + this function if you reuse `image_unet` in another pipeline, e.g. `VersatileDiffusionPipeline` + """ + for name, module in self.image_unet.named_modules(): + if isinstance(module, DualTransformer2DModel): + parent_name, index = name.rsplit(".", 1) + index = int(index) + self.image_unet.get_submodule(parent_name)[index] = module.transformers[0] + + self.image_unet.register_to_config(dual_cross_attention=False) + + def _encode_text_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + """ + + def normalize_embeddings(encoder_output): + embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state) + embeds_pooled = encoder_output.text_embeds + embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True) + return embeds + + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = normalize_embeddings(prompt_embeds) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def _encode_image_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + """ + + def normalize_embeddings(encoder_output): + embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state) + embeds = self.image_encoder.visual_projection(embeds) + embeds_pooled = embeds[:, 0:1] + embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True) + return embeds + + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + image_input = self.image_feature_extractor(images=prompt, return_tensors="pt") + pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype) + image_embeddings = self.image_encoder(pixel_values) + image_embeddings = normalize_embeddings(image_embeddings) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size + uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt") + pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype) + negative_prompt_embeds = self.image_encoder(pixel_values) + negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and conditional embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) + + return image_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, image, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, PIL.Image.Image) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` `PIL.Image` or `list` but is {type(prompt)}") + if not isinstance(image, str) and not isinstance(image, PIL.Image.Image) and not isinstance(image, list): + raise ValueError(f"`image` has to be of type `str` `PIL.Image` or `list` but is {type(image)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def set_transformer_params(self, mix_ratio: float = 0.5, condition_types: tuple = ("text", "image")): + for name, module in self.image_unet.named_modules(): + if isinstance(module, DualTransformer2DModel): + module.mix_ratio = mix_ratio + + for i, type in enumerate(condition_types): + if type == "text": + module.condition_lengths[i] = self.text_encoder.config.max_position_embeddings + module.transformer_index_for_condition[i] = 1 # use the second (text) transformer + else: + module.condition_lengths[i] = 257 + module.transformer_index_for_condition[i] = 0 # use the first (image) transformer + + @torch.no_grad() + def __call__( + self, + prompt: PIL.Image.Image | list[PIL.Image.Image], + image: str | list[str], + text_to_image_strength: float = 0.5, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide image generation. + height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionDualGuidedPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + >>> text = "a red car in the sun" + + >>> pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe.remove_unused_weights() + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> text_to_image_strength = 0.75 + + >>> image = pipe( + ... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator + ... ).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.image_unet.config.sample_size * self.vae_scale_factor + width = width or self.image_unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, image, height, width, callback_steps) + + # 2. Define call parameters + prompt = [prompt] if not isinstance(prompt, list) else prompt + image = [image] if not isinstance(image, list) else image + batch_size = len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompts + prompt_embeds = self._encode_text_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance) + image_embeddings = self._encode_image_prompt(image, device, num_images_per_prompt, do_classifier_free_guidance) + dual_prompt_embeddings = torch.cat([prompt_embeds, image_embeddings], dim=1) + prompt_types = ("text", "image") + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.image_unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + dual_prompt_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Combine the attention blocks of the image and text UNets + self.set_transformer_params(text_to_image_strength, prompt_types) + + # 8. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=dual_prompt_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py new file mode 100644 index 0000000000000000000000000000000000000000..348417ad11dfe703592c7452d2e10a4575d625a4 --- /dev/null +++ b/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -0,0 +1,401 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from ....image_processor import VaeImageProcessor +from ....models import AutoencoderKL, UNet2DConditionModel +from ....schedulers import KarrasDiffusionSchedulers +from ....utils import deprecate, logging +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): + r""" + Pipeline for image variation using Versatile Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [`~transformers.BERT`]. + tokenizer ([`~transformers.BertTokenizer`]): + A `BertTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "bert->unet->vqvae" + + image_feature_extractor: CLIPImageProcessor + image_encoder: CLIPVisionModelWithProjection + image_unet: UNet2DConditionModel + vae: AutoencoderKL + scheduler: KarrasDiffusionSchedulers + + def __init__( + self, + image_feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection, + image_unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + self.register_modules( + image_feature_extractor=image_feature_extractor, + image_encoder=image_encoder, + image_unet=image_unet, + vae=vae, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + + def normalize_embeddings(encoder_output): + embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state) + embeds = self.image_encoder.visual_projection(embeds) + embeds_pooled = embeds[:, 0:1] + embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True) + return embeds + + if isinstance(prompt, torch.Tensor) and len(prompt.shape) == 4: + prompt = list(prompt) + + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + image_input = self.image_feature_extractor(images=prompt, return_tensors="pt") + pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype) + image_embeddings = self.image_encoder(pixel_values) + image_embeddings = normalize_embeddings(image_embeddings) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_images: list[str] + if negative_prompt is None: + uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, PIL.Image.Image): + uncond_images = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_images = negative_prompt + + uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt") + pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype) + negative_prompt_embeds = self.image_encoder(pixel_values) + negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and conditional embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) + + return image_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs + def check_inputs(self, image, height, width, callback_steps): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `list[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + image: PIL.Image.Image | list[PIL.Image.Image] | torch.Tensor, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image`, `list[PIL.Image.Image]` or `torch.Tensor`): + The image prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionImageVariationPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + + >>> pipe = VersatileDiffusionImageVariationPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe(image, generator=generator).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.image_unet.config.sample_size * self.vae_scale_factor + width = width or self.image_unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(image, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(image, PIL.Image.Image) else len(image) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + image_embeddings = self._encode_prompt( + image, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.image_unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py new file mode 100644 index 0000000000000000000000000000000000000000..8e2af1063421dfac5e911835fa8b0c0862fbe5ff --- /dev/null +++ b/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -0,0 +1,479 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import torch +from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer + +from ....image_processor import VaeImageProcessor +from ....models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel +from ....schedulers import KarrasDiffusionSchedulers +from ....utils import deprecate, logging +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .modeling_text_unet import UNetFlatConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Versatile Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [`~transformers.BERT`]. + tokenizer ([`~transformers.BertTokenizer`]): + A `BertTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "bert->unet->vqvae" + + tokenizer: CLIPTokenizer + image_feature_extractor: CLIPImageProcessor + text_encoder: CLIPTextModelWithProjection + image_unet: UNet2DConditionModel + text_unet: UNetFlatConditionModel + vae: AutoencoderKL + scheduler: KarrasDiffusionSchedulers + + _optional_components = ["text_unet"] + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + image_unet: UNet2DConditionModel, + text_unet: UNetFlatConditionModel, + vae: AutoencoderKL, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + if self.text_unet is not None: + self._swap_unet_attention_blocks() + + def _swap_unet_attention_blocks(self): + """ + Swap the `Transformer2DModel` blocks between the image and text UNets + """ + for name, module in self.image_unet.named_modules(): + if isinstance(module, Transformer2DModel): + parent_name, index = name.rsplit(".", 1) + index = int(index) + self.image_unet.get_submodule(parent_name)[index], self.text_unet.get_submodule(parent_name)[index] = ( + self.text_unet.get_submodule(parent_name)[index], + self.image_unet.get_submodule(parent_name)[index], + ) + + def remove_unused_weights(self): + self.register_modules(text_unet=None) + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + + def normalize_embeddings(encoder_output): + embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state) + embeds_pooled = encoder_output.text_embeds + embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True) + return embeds + + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = normalize_embeddings(prompt_embeds) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str], + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide image generation. + height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionTextToImagePipeline + >>> import torch + + >>> pipe = VersatileDiffusionTextToImagePipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe.remove_unused_weights() + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe("an astronaut riding on a horse on mars", generator=generator).images[0] + >>> image.save("./astronaut.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.image_unet.config.sample_size * self.vae_scale_factor + width = width or self.image_unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.image_unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/deprecated/vq_diffusion/__init__.py b/diffusers/pipelines/deprecated/vq_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..070903377c7188415af0417d4839d74a8a34dc01 --- /dev/null +++ b/diffusers/pipelines/deprecated/vq_diffusion/__init__.py @@ -0,0 +1,57 @@ +from typing import TYPE_CHECKING + +from ....utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ....utils.dummy_torch_and_transformers_objects import ( + LearnedClassifierFreeSamplingEmbeddings, + VQDiffusionPipeline, + ) + + _dummy_objects.update( + { + "LearnedClassifierFreeSamplingEmbeddings": LearnedClassifierFreeSamplingEmbeddings, + "VQDiffusionPipeline": VQDiffusionPipeline, + } + ) +else: + _import_structure["pipeline_vq_diffusion"] = ["LearnedClassifierFreeSamplingEmbeddings", "VQDiffusionPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ....utils.dummy_torch_and_transformers_objects import ( + LearnedClassifierFreeSamplingEmbeddings, + VQDiffusionPipeline, + ) + else: + from .pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings, VQDiffusionPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py b/diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..8f3b9512e888fce4ba9b639987f74da7a63bb755 --- /dev/null +++ b/diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py @@ -0,0 +1,325 @@ +# Copyright 2025 Microsoft and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable + +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ....configuration_utils import ConfigMixin, register_to_config +from ....models import ModelMixin, Transformer2DModel, VQModel +from ....schedulers import VQDiffusionScheduler +from ....utils import logging +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin): + """ + Utility class for storing learned text embeddings for classifier free sampling + """ + + @register_to_config + def __init__(self, learnable: bool, hidden_size: int | None = None, length: int | None = None): + super().__init__() + + self.learnable = learnable + + if self.learnable: + assert hidden_size is not None, "learnable=True requires `hidden_size` to be set" + assert length is not None, "learnable=True requires `length` to be set" + + embeddings = torch.zeros(length, hidden_size) + else: + embeddings = None + + self.embeddings = torch.nn.Parameter(embeddings) + + +class VQDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using VQ Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vqvae ([`VQModel`]): + Vector Quantized Variational Auto-Encoder (VAE) model to encode and decode images to and from latent + representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + transformer ([`Transformer2DModel`]): + A conditional `Transformer2DModel` to denoise the encoded image latents. + scheduler ([`VQDiffusionScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + vqvae: VQModel + text_encoder: CLIPTextModel + tokenizer: CLIPTokenizer + transformer: Transformer2DModel + learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings + scheduler: VQDiffusionScheduler + + def __init__( + self, + vqvae: VQModel, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + transformer: Transformer2DModel, + scheduler: VQDiffusionScheduler, + learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings, + ): + super().__init__() + + self.register_modules( + vqvae=vqvae, + transformer=transformer, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings, + ) + + def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + prompt_embeds = self.text_encoder(text_input_ids.to(self.device))[0] + + # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion. + # While CLIP does normalize the pooled output of the text transformer when combining + # the image and text embeddings, CLIP does not directly normalize the last hidden state. + # + # CLIP normalizing the pooled output. + # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053 + prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True) + + # duplicate text embeddings for each generation per prompt + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + if self.learned_classifier_free_sampling_embeddings.learnable: + negative_prompt_embeds = self.learned_classifier_free_sampling_embeddings.embeddings + negative_prompt_embeds = negative_prompt_embeds.unsqueeze(0).repeat(batch_size, 1, 1) + else: + uncond_tokens = [""] * batch_size + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + # See comment for normalizing text embeddings + negative_prompt_embeds = negative_prompt_embeds / negative_prompt_embeds.norm(dim=-1, keepdim=True) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str], + num_inference_steps: int = 100, + guidance_scale: float = 5.0, + truncation_rate: float = 1.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + ) -> ImagePipelineOutput | tuple: + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide image generation. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)): + Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at + most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above + `truncation_rate` are set to zero. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor` of shape (batch), *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Must be valid embedding indices.If not provided, a latents tensor will be generated of + completely masked latent pixels. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # get the initial completely masked latents unless the user supplied it + + latents_shape = (batch_size, self.transformer.num_latent_pixels) + if latents is None: + mask_class = self.transformer.num_vector_embeds - 1 + latents = torch.full(latents_shape, mask_class).to(self.device) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + if (latents < 0).any() or (latents >= self.transformer.num_vector_embeds).any(): + raise ValueError( + "Unexpected latents value(s). All latents be valid embedding indices i.e. in the range 0," + f" {self.transformer.num_vector_embeds - 1} (inclusive)." + ) + latents = latents.to(self.device) + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=self.device) + + timesteps_tensor = self.scheduler.timesteps.to(self.device) + + sample = latents + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the sample if we are doing classifier free guidance + latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample + + # predict the un-noised image + # model_output == `log_p_x_0` + model_output = self.transformer(latent_model_input, encoder_hidden_states=prompt_embeds, timestep=t).sample + + if do_classifier_free_guidance: + model_output_uncond, model_output_text = model_output.chunk(2) + model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond) + model_output -= torch.logsumexp(model_output, dim=1, keepdim=True) + + model_output = self.truncate(model_output, truncation_rate) + + # remove `log(0)`'s (`-inf`s) + model_output = model_output.clamp(-70) + + # compute the previous noisy sample x_t -> x_t-1 + sample = self.scheduler.step(model_output, timestep=t, sample=sample, generator=generator).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, sample) + + embedding_channels = self.vqvae.config.vq_embed_dim + embeddings_shape = (batch_size, self.transformer.height, self.transformer.width, embedding_channels) + embeddings = self.vqvae.quantize.get_codebook_entry(sample, shape=embeddings_shape) + image = self.vqvae.decode(embeddings, force_not_quantize=True).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + + def truncate(self, log_p_x_0: torch.Tensor, truncation_rate: float) -> torch.Tensor: + """ + Truncates `log_p_x_0` such that for each column vector, the total cumulative probability is `truncation_rate` + The lowest probabilities that would increase the cumulative probability above `truncation_rate` are set to + zero. + """ + sorted_log_p_x_0, indices = torch.sort(log_p_x_0, 1, descending=True) + sorted_p_x_0 = torch.exp(sorted_log_p_x_0) + keep_mask = sorted_p_x_0.cumsum(dim=1) < truncation_rate + + # Ensure that at least the largest probability is not zeroed out + all_true = torch.full_like(keep_mask[:, 0:1, :], True) + keep_mask = torch.cat((all_true, keep_mask), dim=1) + keep_mask = keep_mask[:, :-1, :] + + keep_mask = keep_mask.gather(1, indices.argsort(1)) + + rv = log_p_x_0.clone() + + rv[~keep_mask] = -torch.inf # -inf = log(0) + + return rv diff --git a/diffusers/pipelines/easyanimate/__init__.py b/diffusers/pipelines/easyanimate/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..49923423f951c1e6124ddb9b0506f740f28a7a5a --- /dev/null +++ b/diffusers/pipelines/easyanimate/__init__.py @@ -0,0 +1,52 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_easyanimate"] = ["EasyAnimatePipeline"] + _import_structure["pipeline_easyanimate_control"] = ["EasyAnimateControlPipeline"] + _import_structure["pipeline_easyanimate_inpaint"] = ["EasyAnimateInpaintPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_easyanimate import EasyAnimatePipeline + from .pipeline_easyanimate_control import EasyAnimateControlPipeline + from .pipeline_easyanimate_inpaint import EasyAnimateInpaintPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/easyanimate/pipeline_easyanimate.py b/diffusers/pipelines/easyanimate/pipeline_easyanimate.py new file mode 100644 index 0000000000000000000000000000000000000000..6ec8f44e6d1af6f0f541f751ec06152c7ae8505a --- /dev/null +++ b/diffusers/pipelines/easyanimate/pipeline_easyanimate.py @@ -0,0 +1,774 @@ +# Copyright 2025 The EasyAnimate team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import torch +from transformers import ( + BertModel, + BertTokenizer, + Qwen2Tokenizer, + Qwen2VLForConditionalGeneration, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import EasyAnimatePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import EasyAnimatePipeline + >>> from diffusers.utils import export_to_video + + >>> # Models: "alibaba-pai/EasyAnimateV5.1-12b-zh" + >>> pipe = EasyAnimatePipeline.from_pretrained( + ... "alibaba-pai/EasyAnimateV5.1-7b-zh-diffusers", torch_dtype=torch.float16 + ... ).to("cuda") + >>> prompt = ( + ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + ... "atmosphere of this unique musical performance." + ... ) + >>> sample_size = (512, 512) + >>> video = pipe( + ... prompt=prompt, + ... guidance_scale=6, + ... negative_prompt="bad detailed", + ... height=sample_size[0], + ... width=sample_size[1], + ... num_inference_steps=50, + ... ).frames[0] + >>> export_to_video(video, "output.mp4", fps=8) + ``` +""" + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class EasyAnimatePipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using EasyAnimate. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + EasyAnimate uses one text encoder [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1. + + Args: + vae ([`AutoencoderKLMagvit`]): + Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations. + text_encoder (`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel` | None): + EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1. + tokenizer (`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer` | None): + A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text. + transformer ([`EasyAnimateTransformer3DModel`]): + The EasyAnimate model designed by EasyAnimate Team. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKLMagvit, + text_encoder: Qwen2VLForConditionalGeneration | BertModel, + tokenizer: Qwen2Tokenizer | BertTokenizer, + transformer: EasyAnimateTransformer3DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.enable_text_attention_mask = ( + self.transformer.config.enable_text_attention_mask + if getattr(self, "transformer", None) is not None + else True + ) + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 4 + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + + def encode_prompt( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 256, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + dtype (`torch.dtype`): + torch dtype + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. + """ + dtype = dtype or self.text_encoder.dtype + device = device or self.text_encoder.device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + if isinstance(prompt, str): + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": prompt}], + } + ] + else: + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": _prompt}], + } + for _prompt in prompt + ] + text = [ + self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages + ] + + text_inputs = self.tokenizer( + text=text, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_attention_mask=True, + padding_side="right", + return_tensors="pt", + ) + text_inputs = text_inputs.to(self.text_encoder.device) + + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + if self.enable_text_attention_mask: + # Inference: Generation of the output + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ).hidden_states[-2] + else: + raise ValueError("LLM needs attention_mask") + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.to(device=device) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + if negative_prompt is not None and isinstance(negative_prompt, str): + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": negative_prompt}], + } + ] + else: + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": _negative_prompt}], + } + for _negative_prompt in negative_prompt + ] + text = [ + self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages + ] + + text_inputs = self.tokenizer( + text=text, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_attention_mask=True, + padding_side="right", + return_tensors="pt", + ) + text_inputs = text_inputs.to(self.text_encoder.device) + + text_input_ids = text_inputs.input_ids + negative_prompt_attention_mask = text_inputs.attention_mask + if self.enable_text_attention_mask: + # Inference: Generation of the output + negative_prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=negative_prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-2] + else: + raise ValueError("LLM needs attention_mask") + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device) + + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae_temporal_compression_ratio + 1, + height // self.vae_spatial_compression_ratio, + width // self.vae_spatial_compression_ratio, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + num_frames: int | None = 49, + height: int | None = 512, + width: int | None = 512, + num_inference_steps: int | None = 50, + guidance_scale: float | None = 5.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float | None = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + timesteps: list[int] | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + guidance_rescale: float = 0.0, + ): + r""" + Generates images or video using the EasyAnimate pipeline based on the provided prompts. + + Examples: + prompt (`str` or `list[str]`, *optional*): + Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead. + num_frames (`int`, *optional*): + Length of the generated video (in frames). + height (`int`, *optional*): + Height of the generated image in pixels. + width (`int`, *optional*): + Width of the generated image in pixels. + num_inference_steps (`int`, *optional*, defaults to 50): + Number of denoising steps during generation. More steps generally yield higher quality images but slow + down inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + Encourages the model to align outputs with prompts. A higher value may decrease image quality. + negative_prompt (`str` or `list[str]`, *optional*): + Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of images to generate for each prompt. + eta (`float`, *optional*, defaults to 0.0): + Applies to DDIM scheduling. Controlled by the eta parameter from the related literature. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A generator to ensure reproducibility in image generation. + latents (`torch.Tensor`, *optional*): + Predefined latent tensors to condition generation. + prompt_embeds (`torch.Tensor`, *optional*): + Text embeddings for the prompts. Overrides prompt string inputs for more flexibility. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Embeddings for negative prompts. Overrides string inputs if defined. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the primary prompt embeddings. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for negative prompt embeddings. + output_type (`str`, *optional*, defaults to "latent"): + Format of the generated output, either as a PIL image or as a NumPy array. + return_dict (`bool`, *optional*, defaults to `True`): + If `True`, returns a structured output. Otherwise returns a simple tuple. + callback_on_step_end (`Callable`, *optional*): + Functions called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`list[str]`, *optional*): + Tensor names to be included in callback function calls. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Adjusts noise levels based on guidance scale. + original_size (`tuple[int, int]`, *optional*, defaults to `(1024, 1024)`): + Original dimensions of the output. + target_size (`tuple[int, int]`, *optional*): + Desired output dimensions for calculations. + crops_coords_top_left (`tuple[int, int]`, *optional*, defaults to `(0, 0)`): + Coordinates for cropping. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. default height and width + height = int((height // 16) * 16) + width = int((width // 16) * 16) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = self.transformer.dtype + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, mu=1 + ) + else: + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + num_frames, + height, + width, + dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + + prompt_embeds = prompt_embeds.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( + dtype=latent_model_input.dtype + ) + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + return_dict=False, + )[0] + + if noise_pred.size()[1] != self.vae.config.latent_channels: + noise_pred, _ = noise_pred.chunk(2, dim=1) + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + latents = 1 / self.vae.config.scaling_factor * latents + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return EasyAnimatePipelineOutput(frames=video) diff --git a/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py b/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py new file mode 100644 index 0000000000000000000000000000000000000000..5e07996a661c18313c6e4eee554c5017fbf97ca2 --- /dev/null +++ b/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py @@ -0,0 +1,998 @@ +# Copyright 2025 The EasyAnimate team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image +from transformers import ( + BertModel, + BertTokenizer, + Qwen2Tokenizer, + Qwen2VLForConditionalGeneration, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import EasyAnimatePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import EasyAnimateControlPipeline + >>> from diffusers.pipelines.easyanimate.pipeline_easyanimate_control import get_video_to_video_latent + >>> from diffusers.utils import export_to_video, load_video + + >>> pipe = EasyAnimateControlPipeline.from_pretrained( + ... "alibaba-pai/EasyAnimateV5.1-12b-zh-Control-diffusers", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> control_video = load_video( + ... "https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-Control/blob/main/asset/pose.mp4" + ... ) + >>> prompt = ( + ... "In this sunlit outdoor garden, a beautiful woman is dressed in a knee-length, sleeveless white dress. " + ... "The hem of her dress gently sways with her graceful dance, much like a butterfly fluttering in the breeze. " + ... "Sunlight filters through the leaves, casting dappled shadows that highlight her soft features and clear eyes, " + ... "making her appear exceptionally elegant. It seems as if every movement she makes speaks of youth and vitality. " + ... "As she twirls on the grass, her dress flutters, as if the entire garden is rejoicing in her dance. " + ... "The colorful flowers around her sway in the gentle breeze, with roses, chrysanthemums, and lilies each " + ... "releasing their fragrances, creating a relaxed and joyful atmosphere." + ... ) + >>> sample_size = (672, 384) + >>> num_frames = 49 + + >>> input_video, _, _ = get_video_to_video_latent(control_video, num_frames, sample_size) + >>> video = pipe( + ... prompt, + ... num_frames=num_frames, + ... negative_prompt="Twisted body, limb deformities, text subtitles, comics, stillness, ugliness, errors, garbled text.", + ... height=sample_size[0], + ... width=sample_size[1], + ... control_video=input_video, + ... ).frames[0] + >>> export_to_video(video, "output.mp4", fps=8) + ``` +""" + + +def preprocess_image(image, sample_size): + """ + Preprocess a single image (PIL.Image, numpy.ndarray, or torch.Tensor) to a resized tensor. + """ + if isinstance(image, torch.Tensor): + # If input is a tensor, assume it's in CHW format and resize using interpolation + image = torch.nn.functional.interpolate( + image.unsqueeze(0), size=sample_size, mode="bilinear", align_corners=False + ).squeeze(0) + elif isinstance(image, Image.Image): + # If input is a PIL image, resize and convert to numpy array + image = image.resize((sample_size[1], sample_size[0])) + image = np.array(image) + elif isinstance(image, np.ndarray): + # If input is a numpy array, resize using PIL + image = Image.fromarray(image).resize((sample_size[1], sample_size[0])) + image = np.array(image) + else: + raise ValueError("Unsupported input type. Expected PIL.Image, numpy.ndarray, or torch.Tensor.") + + # Convert to tensor if not already + if not isinstance(image, torch.Tensor): + image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0 # HWC -> CHW, normalize to [0, 1] + + return image + + +def get_video_to_video_latent(input_video, num_frames, sample_size, validation_video_mask=None, ref_image=None): + if input_video is not None: + # Convert each frame in the list to tensor + input_video = [preprocess_image(frame, sample_size=sample_size) for frame in input_video] + + # Stack all frames into a single tensor (F, C, H, W) + input_video = torch.stack(input_video)[:num_frames] + + # Add batch dimension (B, F, C, H, W) + input_video = input_video.permute(1, 0, 2, 3).unsqueeze(0) + + if validation_video_mask is not None: + # Handle mask input + validation_video_mask = preprocess_image(validation_video_mask, size=sample_size) + input_video_mask = torch.where(validation_video_mask < 240 / 255.0, 0.0, 255) + + # Adjust mask dimensions to match video + input_video_mask = input_video_mask.unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0) + input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1]) + input_video_mask = input_video_mask.to(input_video.device, input_video.dtype) + else: + input_video_mask = torch.zeros_like(input_video[:, :1]) + input_video_mask[:, :, :] = 255 + else: + input_video, input_video_mask = None, None + + if ref_image is not None: + # Convert reference image to tensor + ref_image = preprocess_image(ref_image, size=sample_size) + ref_image = ref_image.permute(1, 0, 2, 3).unsqueeze(0) # Add batch dimension (B, C, H, W) + else: + ref_image = None + + return input_video, input_video_mask, ref_image + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Resize mask information in magvit +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], size=target_size, mode="trilinear", align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], size=target_size, mode="trilinear", align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate(mask, size=target_size, mode="trilinear", align_corners=False) + return resized_mask + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class EasyAnimateControlPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using EasyAnimate. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + EasyAnimate uses one text encoder [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1. + + Args: + vae ([`AutoencoderKLMagvit`]): + Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations. + text_encoder (`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel` | None): + EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1. + tokenizer (`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer` | None): + A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text. + transformer ([`EasyAnimateTransformer3DModel`]): + The EasyAnimate model designed by EasyAnimate Team. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKLMagvit, + text_encoder: Qwen2VLForConditionalGeneration | BertModel, + tokenizer: Qwen2Tokenizer | BertTokenizer, + transformer: EasyAnimateTransformer3DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.enable_text_attention_mask = ( + self.transformer.config.enable_text_attention_mask + if getattr(self, "transformer", None) is not None + else True + ) + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 4 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_spatial_compression_ratio, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + + # Copied from diffusers.pipelines.easyanimate.pipeline_easyanimate.EasyAnimatePipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 256, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + dtype (`torch.dtype`): + torch dtype + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. + """ + dtype = dtype or self.text_encoder.dtype + device = device or self.text_encoder.device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + if isinstance(prompt, str): + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": prompt}], + } + ] + else: + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": _prompt}], + } + for _prompt in prompt + ] + text = [ + self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages + ] + + text_inputs = self.tokenizer( + text=text, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_attention_mask=True, + padding_side="right", + return_tensors="pt", + ) + text_inputs = text_inputs.to(self.text_encoder.device) + + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + if self.enable_text_attention_mask: + # Inference: Generation of the output + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ).hidden_states[-2] + else: + raise ValueError("LLM needs attention_mask") + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.to(device=device) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + if negative_prompt is not None and isinstance(negative_prompt, str): + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": negative_prompt}], + } + ] + else: + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": _negative_prompt}], + } + for _negative_prompt in negative_prompt + ] + text = [ + self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages + ] + + text_inputs = self.tokenizer( + text=text, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_attention_mask=True, + padding_side="right", + return_tensors="pt", + ) + text_inputs = text_inputs.to(self.text_encoder.device) + + text_input_ids = text_inputs.input_ids + negative_prompt_attention_mask = text_inputs.attention_mask + if self.enable_text_attention_mask: + # Inference: Generation of the output + negative_prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=negative_prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-2] + else: + raise ValueError("LLM needs attention_mask") + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device) + + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae_temporal_compression_ratio + 1, + height // self.vae_spatial_compression_ratio, + width // self.vae_spatial_compression_ratio, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_control_latents( + self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the control to latents shape as we concatenate the control to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if control is not None: + control = control.to(device=device, dtype=dtype) + bs = 1 + new_control = [] + for i in range(0, control.shape[0], bs): + control_bs = control[i : i + bs] + control_bs = self.vae.encode(control_bs)[0] + control_bs = control_bs.mode() + new_control.append(control_bs) + control = torch.cat(new_control, dim=0) + control = control * self.vae.config.scaling_factor + + if control_image is not None: + control_image = control_image.to(device=device, dtype=dtype) + bs = 1 + new_control_pixel_values = [] + for i in range(0, control_image.shape[0], bs): + control_pixel_values_bs = control_image[i : i + bs] + control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0] + control_pixel_values_bs = control_pixel_values_bs.mode() + new_control_pixel_values.append(control_pixel_values_bs) + control_image_latents = torch.cat(new_control_pixel_values, dim=0) + control_image_latents = control_image_latents * self.vae.config.scaling_factor + else: + control_image_latents = None + + return control, control_image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + num_frames: int | None = 49, + height: int | None = 512, + width: int | None = 512, + control_video: torch.FloatTensor = None, + control_camera_video: torch.FloatTensor = None, + ref_image: torch.FloatTensor = None, + num_inference_steps: int | None = 50, + guidance_scale: float | None = 5.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float | None = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + guidance_rescale: float = 0.0, + timesteps: list[int] | None = None, + ): + r""" + Generates images or video using the EasyAnimate pipeline based on the provided prompts. + + Examples: + prompt (`str` or `list[str]`, *optional*): + Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead. + num_frames (`int`, *optional*): + Length of the generated video (in frames). + height (`int`, *optional*): + Height of the generated image in pixels. + width (`int`, *optional*): + Width of the generated image in pixels. + num_inference_steps (`int`, *optional*, defaults to 50): + Number of denoising steps during generation. More steps generally yield higher quality images but slow + down inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + Encourages the model to align outputs with prompts. A higher value may decrease image quality. + negative_prompt (`str` or `list[str]`, *optional*): + Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of images to generate for each prompt. + eta (`float`, *optional*, defaults to 0.0): + Applies to DDIM scheduling. Controlled by the eta parameter from the related literature. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A generator to ensure reproducibility in image generation. + latents (`torch.Tensor`, *optional*): + Predefined latent tensors to condition generation. + prompt_embeds (`torch.Tensor`, *optional*): + Text embeddings for the prompts. Overrides prompt string inputs for more flexibility. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Embeddings for negative prompts. Overrides string inputs if defined. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the primary prompt embeddings. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for negative prompt embeddings. + output_type (`str`, *optional*, defaults to "latent"): + Format of the generated output, either as a PIL image or as a NumPy array. + return_dict (`bool`, *optional*, defaults to `True`): + If `True`, returns a structured output. Otherwise returns a simple tuple. + callback_on_step_end (`Callable`, *optional*): + Functions called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`list[str]`, *optional*): + Tensor names to be included in callback function calls. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Adjusts noise levels based on guidance scale. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. default height and width + height = int((height // 16) * 16) + width = int((width // 16) * 16) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = self.transformer.dtype + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + text_encoder_index=0, + ) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, mu=1 + ) + else: + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + num_frames, + height, + width, + dtype, + device, + generator, + latents, + ) + + if control_camera_video is not None: + control_video_latents = resize_mask(control_camera_video, latents, process_first_frame_only=True) + control_video_latents = control_video_latents * 6 + control_latents = ( + torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents + ).to(device, dtype) + elif control_video is not None: + batch_size, channels, num_frames, height_video, width_video = control_video.shape + control_video = self.image_processor.preprocess( + control_video.permute(0, 2, 1, 3, 4).reshape( + batch_size * num_frames, channels, height_video, width_video + ), + height=height, + width=width, + ) + control_video = control_video.to(dtype=torch.float32) + control_video = control_video.reshape(batch_size, num_frames, channels, height, width).permute( + 0, 2, 1, 3, 4 + ) + control_video_latents = self.prepare_control_latents( + None, + control_video, + batch_size, + height, + width, + dtype, + device, + generator, + self.do_classifier_free_guidance, + )[1] + control_latents = ( + torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents + ).to(device, dtype) + else: + control_video_latents = torch.zeros_like(latents).to(device, dtype) + control_latents = ( + torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents + ).to(device, dtype) + + if ref_image is not None: + batch_size, channels, num_frames, height_video, width_video = ref_image.shape + ref_image = self.image_processor.preprocess( + ref_image.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height_video, width_video), + height=height, + width=width, + ) + ref_image = ref_image.to(dtype=torch.float32) + ref_image = ref_image.reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4) + + ref_image_latents = self.prepare_control_latents( + None, + ref_image, + batch_size, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance, + )[1] + + ref_image_latents_conv_in = torch.zeros_like(latents) + if latents.size()[2] != 1: + ref_image_latents_conv_in[:, :, :1] = ref_image_latents + ref_image_latents_conv_in = ( + torch.cat([ref_image_latents_conv_in] * 2) + if self.do_classifier_free_guidance + else ref_image_latents_conv_in + ).to(device, dtype) + control_latents = torch.cat([control_latents, ref_image_latents_conv_in], dim=1) + else: + ref_image_latents_conv_in = torch.zeros_like(latents) + ref_image_latents_conv_in = ( + torch.cat([ref_image_latents_conv_in] * 2) + if self.do_classifier_free_guidance + else ref_image_latents_conv_in + ).to(device, dtype) + control_latents = torch.cat([control_latents, ref_image_latents_conv_in], dim=1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + + # To latents.device + prompt_embeds = prompt_embeds.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( + dtype=latent_model_input.dtype + ) + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + control_latents=control_latents, + return_dict=False, + )[0] + if noise_pred.size()[1] != self.vae.config.latent_channels: + noise_pred, _ = noise_pred.chunk(2, dim=1) + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # Convert to tensor + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return EasyAnimatePipelineOutput(frames=video) diff --git a/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py b/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..8723138980081a8d04cb50bf174317f50acf1837 --- /dev/null +++ b/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py @@ -0,0 +1,1239 @@ +# Copyright 2025 The EasyAnimate team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image +from transformers import ( + BertModel, + BertTokenizer, + Qwen2Tokenizer, + Qwen2VLForConditionalGeneration, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import EasyAnimatePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import EasyAnimateInpaintPipeline + >>> from diffusers.pipelines.easyanimate.pipeline_easyanimate_inpaint import get_image_to_video_latent + >>> from diffusers.utils import export_to_video, load_image + + >>> pipe = EasyAnimateInpaintPipeline.from_pretrained( + ... "alibaba-pai/EasyAnimateV5.1-12b-zh-InP-diffusers", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." + >>> validation_image_start = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" + ... ) + + >>> validation_image_end = None + >>> sample_size = (448, 576) + >>> num_frames = 49 + >>> input_video, input_video_mask = get_image_to_video_latent( + ... [validation_image_start], validation_image_end, num_frames, sample_size + ... ) + + >>> video = pipe( + ... prompt, + ... num_frames=num_frames, + ... negative_prompt="Twisted body, limb deformities, text subtitles, comics, stillness, ugliness, errors, garbled text.", + ... height=sample_size[0], + ... width=sample_size[1], + ... video=input_video, + ... mask_video=input_video_mask, + ... ) + >>> export_to_video(video.frames[0], "output.mp4", fps=8) + ``` +""" + + +def preprocess_image(image, sample_size): + """ + Preprocess a single image (PIL.Image, numpy.ndarray, or torch.Tensor) to a resized tensor. + """ + if isinstance(image, torch.Tensor): + # If input is a tensor, assume it's in CHW format and resize using interpolation + image = torch.nn.functional.interpolate( + image.unsqueeze(0), size=sample_size, mode="bilinear", align_corners=False + ).squeeze(0) + elif isinstance(image, Image.Image): + # If input is a PIL image, resize and convert to numpy array + image = image.resize((sample_size[1], sample_size[0])) + image = np.array(image) + elif isinstance(image, np.ndarray): + # If input is a numpy array, resize using PIL + image = Image.fromarray(image).resize((sample_size[1], sample_size[0])) + image = np.array(image) + else: + raise ValueError("Unsupported input type. Expected PIL.Image, numpy.ndarray, or torch.Tensor.") + + # Convert to tensor if not already + if not isinstance(image, torch.Tensor): + image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0 # HWC -> CHW, normalize to [0, 1] + + return image + + +def get_image_to_video_latent(validation_image_start, validation_image_end, num_frames, sample_size): + """ + Generate latent representations for video from start and end images. Inputs can be PIL.Image, numpy.ndarray, or + torch.Tensor. + """ + input_video = None + input_video_mask = None + + if validation_image_start is not None: + # Preprocess the starting image(s) + if isinstance(validation_image_start, list): + image_start = [preprocess_image(img, sample_size) for img in validation_image_start] + else: + image_start = preprocess_image(validation_image_start, sample_size) + + # Create video tensor from the starting image(s) + if isinstance(image_start, list): + start_video = torch.cat( + [img.unsqueeze(1).unsqueeze(0) for img in image_start], + dim=2, + ) + input_video = torch.tile(start_video[:, :, :1], [1, 1, num_frames, 1, 1]) + input_video[:, :, : len(image_start)] = start_video + else: + input_video = torch.tile( + image_start.unsqueeze(1).unsqueeze(0), + [1, 1, num_frames, 1, 1], + ) + + # Normalize input video (already normalized in preprocess_image) + + # Create mask for the input video + input_video_mask = torch.zeros_like(input_video[:, :1]) + if isinstance(image_start, list): + input_video_mask[:, :, len(image_start) :] = 255 + else: + input_video_mask[:, :, 1:] = 255 + + # Handle ending image(s) if provided + if validation_image_end is not None: + if isinstance(validation_image_end, list): + image_end = [preprocess_image(img, sample_size) for img in validation_image_end] + end_video = torch.cat( + [img.unsqueeze(1).unsqueeze(0) for img in image_end], + dim=2, + ) + input_video[:, :, -len(end_video) :] = end_video + input_video_mask[:, :, -len(image_end) :] = 0 + else: + image_end = preprocess_image(validation_image_end, sample_size) + input_video[:, :, -1:] = image_end.unsqueeze(1).unsqueeze(0) + input_video_mask[:, :, -1:] = 0 + + elif validation_image_start is None: + # If no starting image is provided, initialize empty tensors + input_video = torch.zeros([1, 3, num_frames, sample_size[0], sample_size[1]]) + input_video_mask = torch.ones([1, 1, num_frames, sample_size[0], sample_size[1]]) * 255 + + return input_video, input_video_mask + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Resize mask information in magvit +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], size=target_size, mode="trilinear", align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], size=target_size, mode="trilinear", align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate(mask, size=target_size, mode="trilinear", align_corners=False) + return resized_mask + + +## Add noise to reference video +def add_noise_to_reference_video(image, ratio=None, generator=None): + if ratio is None: + sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device) + sigma = torch.exp(sigma).to(image.dtype) + else: + sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio + + if generator is not None: + image_noise = ( + torch.randn(image.size(), generator=generator, dtype=image.dtype, device=image.device) + * sigma[:, None, None, None, None] + ) + else: + image_noise = torch.randn_like(image) * sigma[:, None, None, None, None] + image_noise = torch.where(image == -1, torch.zeros_like(image), image_noise) + image = image + image_noise + return image + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class EasyAnimateInpaintPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using EasyAnimate. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + EasyAnimate uses one text encoder [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1. + + Args: + vae ([`AutoencoderKLMagvit`]): + Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations. + text_encoder (`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel` | None): + EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1. + tokenizer (`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer` | None): + A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text. + transformer ([`EasyAnimateTransformer3DModel`]): + The EasyAnimate model designed by EasyAnimate Team. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKLMagvit, + text_encoder: Qwen2VLForConditionalGeneration | BertModel, + tokenizer: Qwen2Tokenizer | BertTokenizer, + transformer: EasyAnimateTransformer3DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.enable_text_attention_mask = ( + self.transformer.config.enable_text_attention_mask + if getattr(self, "transformer", None) is not None + else True + ) + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 4 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_spatial_compression_ratio, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + + # Copied from diffusers.pipelines.easyanimate.pipeline_easyanimate.EasyAnimatePipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 256, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + dtype (`torch.dtype`): + torch dtype + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. + """ + dtype = dtype or self.text_encoder.dtype + device = device or self.text_encoder.device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + if isinstance(prompt, str): + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": prompt}], + } + ] + else: + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": _prompt}], + } + for _prompt in prompt + ] + text = [ + self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages + ] + + text_inputs = self.tokenizer( + text=text, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_attention_mask=True, + padding_side="right", + return_tensors="pt", + ) + text_inputs = text_inputs.to(self.text_encoder.device) + + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + if self.enable_text_attention_mask: + # Inference: Generation of the output + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ).hidden_states[-2] + else: + raise ValueError("LLM needs attention_mask") + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.to(device=device) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + if negative_prompt is not None and isinstance(negative_prompt, str): + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": negative_prompt}], + } + ] + else: + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": _negative_prompt}], + } + for _negative_prompt in negative_prompt + ] + text = [ + self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages + ] + + text_inputs = self.tokenizer( + text=text, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_attention_mask=True, + padding_side="right", + return_tensors="pt", + ) + text_inputs = text_inputs.to(self.text_encoder.device) + + text_input_ids = text_inputs.input_ids + negative_prompt_attention_mask = text_inputs.attention_mask + if self.enable_text_attention_mask: + # Inference: Generation of the output + negative_prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=negative_prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-2] + else: + raise ValueError("LLM needs attention_mask") + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device) + + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + height, + width, + dtype, + device, + generator, + do_classifier_free_guidance, + noise_aug_strength, + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + if mask is not None: + mask = mask.to(device=device, dtype=dtype) + new_mask = [] + bs = 1 + for i in range(0, mask.shape[0], bs): + mask_bs = mask[i : i + bs] + mask_bs = self.vae.encode(mask_bs)[0] + mask_bs = mask_bs.mode() + new_mask.append(mask_bs) + mask = torch.cat(new_mask, dim=0) + mask = mask * self.vae.config.scaling_factor + + if masked_image is not None: + masked_image = masked_image.to(device=device, dtype=dtype) + if self.transformer.config.add_noise_in_inpaint_model: + masked_image = add_noise_to_reference_video( + masked_image, ratio=noise_aug_strength, generator=generator + ) + new_mask_pixel_values = [] + bs = 1 + for i in range(0, masked_image.shape[0], bs): + mask_pixel_values_bs = masked_image[i : i + bs] + mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] + mask_pixel_values_bs = mask_pixel_values_bs.mode() + new_mask_pixel_values.append(mask_pixel_values_bs) + masked_image_latents = torch.cat(new_mask_pixel_values, dim=0) + masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + else: + masked_image_latents = None + + return mask, masked_image_latents + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + num_frames, + dtype, + device, + generator, + latents=None, + video=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_video_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae_temporal_compression_ratio + 1, + height // self.vae_spatial_compression_ratio, + width // self.vae_spatial_compression_ratio, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if return_video_latents or (latents is None and not is_strength_max): + video = video.to(device=device, dtype=dtype) + bs = 1 + new_video = [] + for i in range(0, video.shape[0], bs): + video_bs = video[i : i + bs] + video_bs = self.vae.encode(video_bs)[0] + video_bs = video_bs.sample() + new_video.append(video_bs) + video = torch.cat(new_video, dim=0) + video = video * self.vae.config.scaling_factor + + video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1) + video_latents = video_latents.to(device=device, dtype=dtype) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + latents = noise if is_strength_max else self.scheduler.scale_noise(video_latents, timestep, noise) + else: + latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + if hasattr(self.scheduler, "init_noise_sigma"): + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_video_latents: + outputs += (video_latents,) + + return outputs + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + num_frames: int | None = 49, + video: torch.FloatTensor = None, + mask_video: torch.FloatTensor = None, + masked_video_latents: torch.FloatTensor = None, + height: int | None = 512, + width: int | None = 512, + num_inference_steps: int | None = 50, + guidance_scale: float | None = 5.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float | None = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + guidance_rescale: float = 0.0, + strength: float = 1.0, + noise_aug_strength: float = 0.0563, + timesteps: list[int] | None = None, + ): + r""" + The call function to the pipeline for generation with HunyuanDiT. + + Examples: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + num_frames (`int`, *optional*): + Length of the video to be generated in seconds. This parameter influences the number of frames and + continuity of generated content. + video (`torch.FloatTensor`, *optional*): + A tensor representing an input video, which can be modified depending on the prompts provided. + mask_video (`torch.FloatTensor`, *optional*): + A tensor to specify areas of the video to be masked (omitted from generation). + masked_video_latents (`torch.FloatTensor`, *optional*): + Latents from masked portions of the video, utilized during image generation. + height (`int`, *optional*): + The height in pixels of the generated image or video frames. + width (`int`, *optional*): + The width in pixels of the generated image or video frames. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image but slower + inference time. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is effective when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to exclude in image generation. If not defined, you need to provide + `negative_prompt_embeds`. This parameter is ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + A parameter defined in the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies to the + [`~schedulers.DDIMScheduler`] and is ignored in other schedulers. It adjusts noise level during the + inference process. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) for setting + random seeds which helps in making generation deterministic. + latents (`torch.Tensor`, *optional*): + A pre-computed latent representation which can be used to guide the generation process. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings, aiding in fine-tuning what should not be represented in the + outputs. If not provided, embeddings are generated from the `negative_prompt` argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask guiding the focus of the model on specific parts of the prompt text. Required when using + `prompt_embeds`. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt, needed when `negative_prompt_embeds` are used. + output_type (`str`, *optional*, defaults to `"latent"`): + The output format of the generated image. Choose between `PIL.Image` and `np.array` to define how you + want the results to be formatted. + return_dict (`bool`, *optional*, defaults to `True`): + If set to `True`, a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] will be returned; + otherwise, a tuple containing the generated images and safety flags will be returned. + callback_on_step_end (`Callable[[int, int], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, + *optional*): + A callback function (or a list of them) that will be executed at the end of each denoising step, + allowing for custom processing during generation. + callback_on_step_end_tensor_inputs (`list[str]`, *optional*): + Specifies which tensor inputs should be included in the callback function. If not defined, all tensor + inputs will be passed, facilitating enhanced logging or monitoring of the generation process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Rescale parameter for adjusting noise configuration based on guidance rescale. Based on findings from + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + strength (`float`, *optional*, defaults to 1.0): + Affects the overall styling or quality of the generated output. Values closer to 1 usually provide + direct adherence to prompts. + + Examples: + # Example usage of the function for generating images based on prompts. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + Returns either a structured output containing generated images and their metadata when `return_dict` is + `True`, or a simpler tuple, where the first element is a list of generated images and the second + element indicates if any of them contain "not-safe-for-work" (NSFW) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. default height and width + height = int(height // 16 * 16) + width = int(width // 16 * 16) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = self.transformer.dtype + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + # 4. set timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, mu=1 + ) + else: + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + if video is not None: + batch_size, channels, num_frames, height_video, width_video = video.shape + init_video = self.image_processor.preprocess( + video.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height_video, width_video), + height=height, + width=width, + ) + init_video = init_video.to(dtype=torch.float32) + init_video = init_video.reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4) + else: + init_video = None + + # Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_transformer = self.transformer.config.in_channels + return_image_latents = num_channels_transformer == num_channels_latents + + # 5. Prepare latents. + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + num_frames, + dtype, + device, + generator, + latents, + video=init_video, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_video_latents=return_image_latents, + ) + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 6. Prepare inpaint latents if it needs. + if mask_video is not None: + if (mask_video == 255).all(): + mask = torch.zeros_like(latents).to(device, dtype) + # Use zero latents if we want to t2v. + if self.transformer.config.resize_inpaint_mask_directly: + mask_latents = torch.zeros_like(latents)[:, :1].to(device, dtype) + else: + mask_latents = torch.zeros_like(latents).to(device, dtype) + masked_video_latents = torch.zeros_like(latents).to(device, dtype) + + mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents + ) + inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(dtype) + else: + # Prepare mask latent variables + batch_size, channels, num_frames, height_video, width_video = mask_video.shape + mask_condition = self.mask_processor.preprocess( + mask_video.permute(0, 2, 1, 3, 4).reshape( + batch_size * num_frames, channels, height_video, width_video + ), + height=height, + width=width, + ) + mask_condition = mask_condition.to(dtype=torch.float32) + mask_condition = mask_condition.reshape(batch_size, num_frames, channels, height, width).permute( + 0, 2, 1, 3, 4 + ) + + if num_channels_transformer != num_channels_latents: + mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1]) + if masked_video_latents is None: + masked_video = ( + init_video * (mask_condition_tile < 0.5) + + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1 + ) + else: + masked_video = masked_video_latents + + if self.transformer.config.resize_inpaint_mask_directly: + _, masked_video_latents = self.prepare_mask_latents( + None, + masked_video, + batch_size, + height, + width, + dtype, + device, + generator, + self.do_classifier_free_guidance, + noise_aug_strength=noise_aug_strength, + ) + mask_latents = resize_mask( + 1 - mask_condition, masked_video_latents, self.vae.config.cache_mag_vae + ) + mask_latents = mask_latents.to(device, dtype) * self.vae.config.scaling_factor + else: + mask_latents, masked_video_latents = self.prepare_mask_latents( + mask_condition_tile, + masked_video, + batch_size, + height, + width, + dtype, + device, + generator, + self.do_classifier_free_guidance, + noise_aug_strength=noise_aug_strength, + ) + + mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) + if self.do_classifier_free_guidance + else masked_video_latents + ) + inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(dtype) + else: + inpaint_latents = None + + mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1]) + mask = F.interpolate(mask, size=latents.size()[-3:], mode="trilinear", align_corners=True).to( + device, dtype + ) + else: + if num_channels_transformer != num_channels_latents: + mask = torch.zeros_like(latents).to(device, dtype) + if self.transformer.config.resize_inpaint_mask_directly: + mask_latents = torch.zeros_like(latents)[:, :1].to(device, dtype) + else: + mask_latents = torch.zeros_like(latents).to(device, dtype) + masked_video_latents = torch.zeros_like(latents).to(device, dtype) + + mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents + ) + inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(dtype) + else: + mask = torch.zeros_like(init_video[:, :1]) + mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1]) + mask = F.interpolate(mask, size=latents.size()[-3:], mode="trilinear", align_corners=True).to( + device, dtype + ) + + inpaint_latents = None + + # Check that sizes of mask, masked image and latents match + if num_channels_transformer != num_channels_latents: + num_channels_mask = mask_latents.shape[1] + num_channels_masked_image = masked_video_latents.shape[1] + if ( + num_channels_latents + num_channels_mask + num_channels_masked_image + != self.transformer.config.in_channels + ): + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects" + f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" + " `pipeline.transformer` or your `mask_image` or `image` input." + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + + # To latents.device + prompt_embeds = prompt_embeds.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( + dtype=latent_model_input.dtype + ) + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + inpaint_latents=inpaint_latents, + return_dict=False, + )[0] + if noise_pred.size()[1] != self.vae.config.latent_channels: + noise_pred, _ = noise_pred.chunk(2, dim=1) + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if num_channels_transformer == num_channels_latents: + init_latents_proper = image_latents + init_mask = mask + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep], noise) + ) + else: + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + latents = 1 / self.vae.config.scaling_factor * latents + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return EasyAnimatePipelineOutput(frames=video) diff --git a/diffusers/pipelines/easyanimate/pipeline_output.py b/diffusers/pipelines/easyanimate/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..4fd904ae7dfeb3e842cbb62844df643c7e5b9889 --- /dev/null +++ b/diffusers/pipelines/easyanimate/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class EasyAnimatePipelineOutput(BaseOutput): + r""" + Output class for EasyAnimate pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/diffusers/pipelines/flux2/__init__.py b/diffusers/pipelines/flux2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f6e1d520663038b617400375b32d38e9371f7268 --- /dev/null +++ b/diffusers/pipelines/flux2/__init__.py @@ -0,0 +1,49 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["Flux2PipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_flux2"] = ["Flux2Pipeline"] + _import_structure["pipeline_flux2_klein"] = ["Flux2KleinPipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_flux2 import Flux2Pipeline + from .pipeline_flux2_klein import Flux2KleinPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/flux2/pipeline_flux2.py b/diffusers/pipelines/flux2/pipeline_flux2.py new file mode 100644 index 0000000000000000000000000000000000000000..6cd0563fcc1905e746d9e945db5b2ab674c07e7a --- /dev/null +++ b/diffusers/pipelines/flux2/pipeline_flux2.py @@ -0,0 +1,1032 @@ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL +import torch +from transformers import AutoProcessor, Mistral3ForConditionalGeneration + +from ...loaders import Flux2LoraLoaderMixin +from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .image_processor import Flux2ImageProcessor +from .pipeline_output import Flux2PipelineOutput +from .system_messages import SYSTEM_MESSAGE, SYSTEM_MESSAGE_UPSAMPLING_I2I, SYSTEM_MESSAGE_UPSAMPLING_T2I + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import Flux2Pipeline + + >>> pipe = Flux2Pipeline.from_pretrained("black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0] + >>> image.save("flux.png") + ``` +""" + +UPSAMPLING_MAX_IMAGE_SIZE = 768**2 + + +# Adapted from +# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L68 +def format_input( + prompts: list[str], + system_message: str = SYSTEM_MESSAGE, + images: list[PIL.Image.Image, list[list[PIL.Image.Image]]] | None = None, +): + """ + Format a batch of text prompts into the conversation format expected by apply_chat_template. Optionally, add images + to the input. + + Args: + prompts: List of text prompts + system_message: System message to use (default: CREATIVE_SYSTEM_MESSAGE) + images (optional): List of images to add to the input. + + Returns: + List of conversations, where each conversation is a list of message dicts + """ + # Remove [IMG] tokens from prompts to avoid Pixtral validation issues + # when truncation is enabled. The processor counts [IMG] tokens and fails + # if the count changes after truncation. + cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts] + + if images is None or len(images) == 0: + return [ + [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + }, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + for prompt in cleaned_txt + ] + else: + assert len(images) == len(prompts), "Number of images must match number of prompts" + messages = [ + [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + }, + ] + for _ in cleaned_txt + ] + + for i, (el, images) in enumerate(zip(messages, images)): + # optionally add the images per batch element. + if images is not None: + el.append( + { + "role": "user", + "content": [{"type": "image", "image": image_obj} for image_obj in images], + } + ) + # add the text. + el.append( + { + "role": "user", + "content": [{"type": "text", "text": cleaned_txt[i]}], + } + ) + + return messages + + +# Adapted from +# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L49C5-L66C19 +def _validate_and_process_images( + images: list[list[PIL.Image.Image]] | list[PIL.Image.Image], + image_processor: Flux2ImageProcessor, + upsampling_max_image_size: int, +) -> list[list[PIL.Image.Image]]: + # Simple validation: ensure it's a list of PIL images or list of lists of PIL images + if not images: + return [] + + # Check if it's a list of lists or a list of images + if isinstance(images[0], PIL.Image.Image): + # It's a list of images, convert to list of lists + images = [[im] for im in images] + + # potentially concatenate multiple images to reduce the size + images = [[image_processor.concatenate_images(img_i)] if len(img_i) > 1 else img_i for img_i in images] + + # cap the pixels + images = [ + [image_processor._resize_if_exceeds_area(img_i, upsampling_max_image_size) for img_i in img_i] + for img_i in images + ] + return images + + +# Taken from +# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/sampling.py#L251 +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin): + r""" + The Flux2 pipeline for text-to-image generation. + + Reference: [https://bfl.ai/blog/flux-2](https://bfl.ai/blog/flux-2) + + Args: + transformer ([`Flux2Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLFlux2`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Mistral3ForConditionalGeneration`]): + [Mistral3ForConditionalGeneration](https://huggingface.co/docs/transformers/en/model_doc/mistral3#transformers.Mistral3ForConditionalGeneration) + tokenizer (`AutoProcessor`): + Tokenizer of class + [PixtralProcessor](https://huggingface.co/docs/transformers/en/model_doc/pixtral#transformers.PixtralProcessor). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLFlux2, + text_encoder: Mistral3ForConditionalGeneration, + tokenizer: AutoProcessor, + transformer: Flux2Transformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 512 + self.default_sample_size = 128 + + self.system_message = SYSTEM_MESSAGE + self.system_message_upsampling_t2i = SYSTEM_MESSAGE_UPSAMPLING_T2I + self.system_message_upsampling_i2i = SYSTEM_MESSAGE_UPSAMPLING_I2I + self.upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE + + @staticmethod + def _get_mistral_3_small_prompt_embeds( + text_encoder: Mistral3ForConditionalGeneration, + tokenizer: AutoProcessor, + prompt: str | list[str], + dtype: torch.dtype | None = None, + device: torch.device | None = None, + max_sequence_length: int = 512, + system_message: str = SYSTEM_MESSAGE, + hidden_states_layers: list[int] = (10, 20, 30), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + # Format input messages + messages_batch = format_input(prompts=prompt, system_message=system_message) + + # Process all messages at once + inputs = tokenizer.apply_chat_template( + messages_batch, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + # Move to device + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @staticmethod + def _prepare_text_ids( + x: torch.Tensor, # (B, L, D) or (L, D) + t_coord: torch.Tensor | None = None, + ): + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, l) + out_ids.append(coords) + + return torch.stack(out_ids) + + @staticmethod + def _prepare_latent_ids( + latents: torch.Tensor, # (B, C, H, W) + ): + r""" + Generates 4D position coordinates (T, H, W, L) for latent tensors. + + Args: + latents (torch.Tensor): + Latent tensor of shape (B, C, H, W) + + Returns: + torch.Tensor: + Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0, + H=[0..H-1], W=[0..W-1], L=0 + """ + + batch_size, _, height, width = latents.shape + + t = torch.arange(1) # [0] - time dimension + h = torch.arange(height) + w = torch.arange(width) + l = torch.arange(1) # [0] - layer dimension + + # Create position IDs: (H*W, 4) + latent_ids = torch.cartesian_prod(t, h, w, l) + + # Expand to batch: (B, H*W, 4) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + + return latent_ids + + @staticmethod + def _prepare_image_ids( + image_latents: list[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] + scale: int = 10, + ): + r""" + Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents. + + This function creates a unique coordinate for every pixel/patch across all input latent with different + dimensions. + + Args: + image_latents (list[torch.Tensor]): + A list of image latent feature tensors, typically of shape (C, H, W). + scale (int, optional): + A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th + latent is: 'scale + scale * i'. Defaults to 10. + + Returns: + torch.Tensor: + The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all + input latents. + + Coordinate Components (Dimension 4): + - T (Time): The unique index indicating which latent image the coordinate belongs to. + - H (Height): The row index within that latent image. + - W (Width): The column index within that latent image. + - L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1) + """ + + if not isinstance(image_latents, list): + raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.") + + # create time offset for each reference image + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] + t_coords = [t.view(-1) for t in t_coords] + + image_latent_ids = [] + for x, t in zip(image_latents, t_coords): + x = x.squeeze(0) + _, height, width = x.shape + + x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) + image_latent_ids.append(x_ids) + + image_latent_ids = torch.cat(image_latent_ids, dim=0) + image_latent_ids = image_latent_ids.unsqueeze(0) + + return image_latent_ids + + @staticmethod + def _patchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2) + return latents + + @staticmethod + def _unpatchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2) + return latents + + @staticmethod + def _pack_latents(latents): + """ + pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels) + """ + + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + + return latents + + @staticmethod + def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]: + """ + using position ids to scatter tokens into place + """ + x_list = [] + for data, pos in zip(x, x_ids): + _, ch = data.shape # noqa: F841 + h_ids = pos[:, 1].to(torch.int64) + w_ids = pos[:, 2].to(torch.int64) + + h = torch.max(h_ids) + 1 + w = torch.max(w_ids) + 1 + + flat_ids = h_ids * w + w_ids + + out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype) + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) + + # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W) + + out = out.view(h, w, ch).permute(2, 0, 1) + x_list.append(out) + + return torch.stack(x_list, dim=0) + + def upsample_prompt( + self, + prompt: str | list[str], + images: list[PIL.Image.Image, list[list[PIL.Image.Image]]] = None, + temperature: float = 0.15, + device: torch.device = None, + ) -> list[str]: + prompt = [prompt] if isinstance(prompt, str) else prompt + device = self.text_encoder.device if device is None else device + + # Set system message based on whether images are provided + if images is None or len(images) == 0 or images[0] is None: + system_message = SYSTEM_MESSAGE_UPSAMPLING_T2I + else: + system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I + + # Validate and process the input images + if images: + images = _validate_and_process_images(images, self.image_processor, self.upsampling_max_image_size) + + # Format input messages + messages_batch = format_input(prompts=prompt, system_message=system_message, images=images) + + # Process all messages at once + # with image processing a too short max length can throw an error in here. + inputs = self.tokenizer.apply_chat_template( + messages_batch, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=2048, + ) + + # Move to device + inputs["input_ids"] = inputs["input_ids"].to(device) + inputs["attention_mask"] = inputs["attention_mask"].to(device) + + if "pixel_values" in inputs: + inputs["pixel_values"] = inputs["pixel_values"].to(device, self.text_encoder.dtype) + + # Generate text using the model's generate method + generated_ids = self.text_encoder.generate( + **inputs, + max_new_tokens=512, + do_sample=True, + temperature=temperature, + use_cache=True, + ) + + # Decode only the newly generated tokens (skip input tokens) + # Extract only the generated portion + input_length = inputs["input_ids"].shape[1] + generated_tokens = generated_ids[:, input_length:] + + upsampled_prompt = self.tokenizer.tokenizer.batch_decode( + generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + return upsampled_prompt + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 512, + text_encoder_out_layers: tuple[int] = (10, 20, 30), + ): + device = device or self._execution_device + + if prompt is None: + prompt = "" + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds = self._get_mistral_3_small_prompt_embeds( + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=max_sequence_length, + system_message=self.system_message, + hidden_states_layers=text_encoder_out_layers, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = self._prepare_text_ids(prompt_embeds) + text_ids = text_ids.to(device) + return prompt_embeds, text_ids + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if image.ndim != 4: + raise ValueError(f"Expected image dims 4, got {image.ndim}.") + + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = self._patchify_latents(image_latents) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps) + image_latents = (image_latents - latents_bn_mean) / latents_bn_std + + return image_latents + + def prepare_latents( + self, + batch_size, + num_latents_channels, + height, + width, + dtype, + device, + generator: torch.Generator, + latents: torch.Tensor | None = None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_latents_channels * 4, height // 2, width // 2) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + latent_ids = self._prepare_latent_ids(latents) + latent_ids = latent_ids.to(device) + + latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C] + return latents, latent_ids + + def prepare_image_latents( + self, + images: list[torch.Tensor], + batch_size, + generator: torch.Generator, + device, + dtype, + ): + image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + imagge_latent = self._encode_vae_image(image=image, generator=generator) + image_latents.append(imagge_latent) # (1, 128, 32, 32) + + image_latent_ids = self._prepare_image_ids(image_latents) + + # Pack each latent and concatenate + packed_latents = [] + for latent in image_latents: + # latent: (1, 128, 32, 32) + packed = self._pack_latents(latent) # (1, 1024, 128) + packed = packed.squeeze(0) # (1024, 128) - remove batch dim + packed_latents.append(packed) + + # Concatenate all reference tokens along sequence dimension + image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128) + image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128) + + image_latents = image_latents.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.to(device) + + return image_latents, image_latent_ids + + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if ( + height is not None + and height % (self.vae_scale_factor * 2) != 0 + or width is not None + and width % (self.vae_scale_factor * 2) != 0 + ): + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: list[PIL.Image.Image, PIL.Image.Image] | None = None, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float | None = 4.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + text_encoder_out_layers: tuple[int] = (10, 20, 30), + caption_upsample_temperature: float = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + guidance_scale (`float`, *optional*, defaults to 1.0): + Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + text_encoder_out_layers (`tuple[int]`): + Layer indices to use in the `text_encoder` to derive the final prompt embeddings. + caption_upsample_temperature (`float`): + When specified, we will try to perform caption upsampling for potentially improved outputs. We + recommend setting it to 0.15 if caption upsampling is to be performed. + + Examples: + + Returns: + [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. prepare text embeddings + if caption_upsample_temperature: + prompt = self.upsample_prompt( + prompt, images=image, temperature=caption_upsample_temperature, device=device + ) + prompt_embeds, text_ids = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + text_encoder_out_layers=text_encoder_out_layers, + ) + + # 4. process images + if image is not None and not isinstance(image, list): + image = [image] + + condition_images = None + if image is not None: + for img in image: + self.image_processor.check_image_input(img) + + condition_images = [] + for img in image: + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + img = self.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + + multiple_of = self.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") + condition_images.append(img) + height = height or image_height + width = width or image_width + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 5. prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_ids = self.prepare_latents( + batch_size=batch_size * num_images_per_prompt, + num_latents_channels=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + image_latents = None + image_latent_ids = None + if condition_images is not None: + image_latents, image_latent_ids = self.prepare_image_latents( + images=condition_images, + batch_size=batch_size * num_images_per_prompt, + generator=generator, + device=device, + dtype=self.vae.dtype, + ) + + # 6. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None + image_seq_len = latents.shape[1] + mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + + # 7. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + latent_model_input = latents.to(self.transformer.dtype) + latent_image_ids = latent_ids + + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype) + latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1) + + noise_pred = self.transformer( + hidden_states=latent_model_input, # (B, image_seq_len, C) + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, # B, text_seq_len, 4 + img_ids=latent_image_ids, # B, image_seq_len, 4 + joint_attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1) :] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents_with_ids(latents, latent_ids) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + latents = latents * latents_bn_std + latents_bn_mean + latents = self._unpatchify_latents(latents) + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return Flux2PipelineOutput(images=image) diff --git a/diffusers/pipelines/flux2/pipeline_flux2_klein.py b/diffusers/pipelines/flux2/pipeline_flux2_klein.py new file mode 100644 index 0000000000000000000000000000000000000000..936d2c3804ab840b13db9b5b56415cbef728ee0a --- /dev/null +++ b/diffusers/pipelines/flux2/pipeline_flux2_klein.py @@ -0,0 +1,918 @@ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL +import torch +from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM + +from ...loaders import Flux2LoraLoaderMixin +from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .image_processor import Flux2ImageProcessor +from .pipeline_output import Flux2PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import Flux2KleinPipeline + + >>> pipe = Flux2KleinPipeline.from_pretrained( + ... "black-forest-labs/FLUX.2-klein-base-9B", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=4.0).images[0] + >>> image.save("flux2_output.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux2.pipeline_flux2.compute_empirical_mu +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class Flux2KleinPipeline(DiffusionPipeline, Flux2LoraLoaderMixin): + r""" + The Flux2 Klein pipeline for text-to-image generation. + + Reference: + [https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence](https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence) + + Args: + transformer ([`Flux2Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLFlux2`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen3ForCausalLM`]): + [Qwen3ForCausalLM](https://huggingface.co/docs/transformers/en/model_doc/qwen3#transformers.Qwen3ForCausalLM) + tokenizer (`Qwen2TokenizerFast`): + Tokenizer of class + [Qwen2TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLFlux2, + text_encoder: Qwen3ForCausalLM, + tokenizer: Qwen2TokenizerFast, + transformer: Flux2Transformer2DModel, + is_distilled: bool = False, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + ) + + self.register_to_config(is_distilled=is_distilled) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 512 + self.default_sample_size = 128 + + @staticmethod + def _get_qwen3_prompt_embeds( + text_encoder: Qwen3ForCausalLM, + tokenizer: Qwen2TokenizerFast, + prompt: str | list[str], + dtype: torch.dtype | None = None, + device: torch.device | None = None, + max_sequence_length: int = 512, + hidden_states_layers: list[int] = (9, 18, 27), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + all_input_ids = [] + all_attention_masks = [] + + for single_prompt in prompt: + messages = [{"role": "user", "content": single_prompt}] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + all_input_ids.append(inputs["input_ids"]) + all_attention_masks.append(inputs["attention_mask"]) + + input_ids = torch.cat(all_input_ids, dim=0).to(device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_text_ids + def _prepare_text_ids( + x: torch.Tensor, # (B, L, D) or (L, D) + t_coord: torch.Tensor | None = None, + ): + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, l) + out_ids.append(coords) + + return torch.stack(out_ids) + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids + def _prepare_latent_ids( + latents: torch.Tensor, # (B, C, H, W) + ): + r""" + Generates 4D position coordinates (T, H, W, L) for latent tensors. + + Args: + latents (torch.Tensor): + Latent tensor of shape (B, C, H, W) + + Returns: + torch.Tensor: + Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0, + H=[0..H-1], W=[0..W-1], L=0 + """ + + batch_size, _, height, width = latents.shape + + t = torch.arange(1) # [0] - time dimension + h = torch.arange(height) + w = torch.arange(width) + l = torch.arange(1) # [0] - layer dimension + + # Create position IDs: (H*W, 4) + latent_ids = torch.cartesian_prod(t, h, w, l) + + # Expand to batch: (B, H*W, 4) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + + return latent_ids + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids + def _prepare_image_ids( + image_latents: list[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] + scale: int = 10, + ): + r""" + Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents. + + This function creates a unique coordinate for every pixel/patch across all input latent with different + dimensions. + + Args: + image_latents (list[torch.Tensor]): + A list of image latent feature tensors, typically of shape (C, H, W). + scale (int, optional): + A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th + latent is: 'scale + scale * i'. Defaults to 10. + + Returns: + torch.Tensor: + The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all + input latents. + + Coordinate Components (Dimension 4): + - T (Time): The unique index indicating which latent image the coordinate belongs to. + - H (Height): The row index within that latent image. + - W (Width): The column index within that latent image. + - L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1) + """ + + if not isinstance(image_latents, list): + raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.") + + # create time offset for each reference image + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] + t_coords = [t.view(-1) for t in t_coords] + + image_latent_ids = [] + for x, t in zip(image_latents, t_coords): + x = x.squeeze(0) + _, height, width = x.shape + + x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) + image_latent_ids.append(x_ids) + + image_latent_ids = torch.cat(image_latent_ids, dim=0) + image_latent_ids = image_latent_ids.unsqueeze(0) + + return image_latent_ids + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents + def _patchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpatchify_latents + def _unpatchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._pack_latents + def _pack_latents(latents): + """ + pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels) + """ + + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids + def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]: + """ + using position ids to scatter tokens into place + """ + x_list = [] + for data, pos in zip(x, x_ids): + _, ch = data.shape # noqa: F841 + h_ids = pos[:, 1].to(torch.int64) + w_ids = pos[:, 2].to(torch.int64) + + h = torch.max(h_ids) + 1 + w = torch.max(w_ids) + 1 + + flat_ids = h_ids * w + w_ids + + out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype) + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) + + # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W) + + out = out.view(h, w, ch).permute(2, 0, 1) + x_list.append(out) + + return torch.stack(x_list, dim=0) + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 512, + text_encoder_out_layers: tuple[int] = (9, 18, 27), + ): + device = device or self._execution_device + + if prompt is None: + prompt = "" + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds = self._get_qwen3_prompt_embeds( + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=max_sequence_length, + hidden_states_layers=text_encoder_out_layers, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = self._prepare_text_ids(prompt_embeds) + text_ids = text_ids.to(device) + return prompt_embeds, text_ids + + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if image.ndim != 4: + raise ValueError(f"Expected image dims 4, got {image.ndim}.") + + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = self._patchify_latents(image_latents) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps) + image_latents = (image_latents - latents_bn_mean) / latents_bn_std + + return image_latents + + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_latents_channels, + height, + width, + dtype, + device, + generator: torch.Generator, + latents: torch.Tensor | None = None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_latents_channels * 4, height // 2, width // 2) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + latent_ids = self._prepare_latent_ids(latents) + latent_ids = latent_ids.to(device) + + latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C] + return latents, latent_ids + + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_image_latents + def prepare_image_latents( + self, + images: list[torch.Tensor], + batch_size, + generator: torch.Generator, + device, + dtype, + ): + image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + imagge_latent = self._encode_vae_image(image=image, generator=generator) + image_latents.append(imagge_latent) # (1, 128, 32, 32) + + image_latent_ids = self._prepare_image_ids(image_latents) + + # Pack each latent and concatenate + packed_latents = [] + for latent in image_latents: + # latent: (1, 128, 32, 32) + packed = self._pack_latents(latent) # (1, 1024, 128) + packed = packed.squeeze(0) # (1024, 128) - remove batch dim + packed_latents.append(packed) + + # Concatenate all reference tokens along sequence dimension + image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128) + image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128) + + image_latents = image_latents.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.to(device) + + return image_latents, image_latent_ids + + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + guidance_scale=None, + ): + if ( + height is not None + and height % (self.vae_scale_factor * 2) != 0 + or width is not None + and width % (self.vae_scale_factor * 2) != 0 + ): + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if guidance_scale > 1.0 and self.config.is_distilled: + logger.warning(f"Guidance scale {guidance_scale} is ignored for step-wise distilled models.") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and not self.config.is_distilled + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: list[PIL.Image.Image] | PIL.Image.Image | None = None, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: str | list[str] | None = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int, dict], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + text_encoder_out_layers: tuple[int] = (9, 18, 27), + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. For step-wise distilled models, + `guidance_scale` is ignored. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Note that "" is used as the negative prompt in this pipeline. + If not provided, will be generated from "". + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + text_encoder_out_layers (`tuple[int]`): + Layer indices to use in the `text_encoder` to derive the final prompt embeddings. + + Examples: + + Returns: + [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + guidance_scale=guidance_scale, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. prepare text embeddings + prompt_embeds, text_ids = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + text_encoder_out_layers=text_encoder_out_layers, + ) + + if self.do_classifier_free_guidance: + negative_prompt = "" + if prompt is not None and isinstance(prompt, list): + negative_prompt = [negative_prompt] * len(prompt) + negative_prompt_embeds, negative_text_ids = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + text_encoder_out_layers=text_encoder_out_layers, + ) + + # 4. process images + if image is not None and not isinstance(image, list): + image = [image] + + condition_images = None + if image is not None: + for img in image: + self.image_processor.check_image_input(img) + + condition_images = [] + for img in image: + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + img = self.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + + multiple_of = self.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") + condition_images.append(img) + height = height or image_height + width = width or image_width + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 5. prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_ids = self.prepare_latents( + batch_size=batch_size * num_images_per_prompt, + num_latents_channels=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + image_latents = None + image_latent_ids = None + if condition_images is not None: + image_latents, image_latent_ids = self.prepare_image_latents( + images=condition_images, + batch_size=batch_size * num_images_per_prompt, + generator=generator, + device=device, + dtype=self.vae.dtype, + ) + + # 6. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None + image_seq_len = latents.shape[1] + mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 7. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + latent_model_input = latents.to(self.transformer.dtype) + latent_image_ids = latent_ids + + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype) + latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1) + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, # (B, image_seq_len, C) + timestep=timestep / 1000, + guidance=None, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, # B, text_seq_len, 4 + img_ids=latent_image_ids, # B, image_seq_len, 4 + joint_attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1) :] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=None, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self._attention_kwargs, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1) :] + noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + latents = self._unpack_latents_with_ids(latents, latent_ids) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + latents = latents * latents_bn_std + latents_bn_mean + latents = self._unpatchify_latents(latents) + if output_type == "latent": + image = latents + else: + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return Flux2PipelineOutput(images=image) diff --git a/diffusers/pipelines/flux2/pipeline_output.py b/diffusers/pipelines/flux2/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..34ae9b574453349fdf42c2e98680f1cee315fa0a --- /dev/null +++ b/diffusers/pipelines/flux2/pipeline_output.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class Flux2PipelineOutput(BaseOutput): + """ + Output class for Flux2 image generation pipelines. + + Args: + images (`list[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size, + height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion + pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be + passed to the decoder. + """ + + images: list[PIL.Image.Image, np.ndarray] diff --git a/diffusers/pipelines/flux2/system_messages.py b/diffusers/pipelines/flux2/system_messages.py new file mode 100644 index 0000000000000000000000000000000000000000..ecdb1371f0d49afa94788e0f02f89837469a706d --- /dev/null +++ b/diffusers/pipelines/flux2/system_messages.py @@ -0,0 +1,33 @@ +# docstyle-ignore +""" +These system prompts come from: +https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/system_messages.py#L54 +""" + +# docstyle-ignore +SYSTEM_MESSAGE = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object +attribution and actions without speculation.""" + +# docstyle-ignore +SYSTEM_MESSAGE_UPSAMPLING_T2I = """You are an expert prompt engineer for FLUX.2 by Black Forest Labs. Rewrite user prompts to be more descriptive while strictly preserving their core subject and intent. + +Guidelines: +1. Structure: Keep structured inputs structured (enhance within fields). Convert natural language to detailed paragraphs. +2. Details: Add concrete visual specifics - form, scale, textures, materials, lighting (quality, direction, color), shadows, spatial relationships, and environmental context. +3. Text in Images: Put ALL text in quotation marks, matching the prompt's language. Always provide explicit quoted text for objects that would contain text in reality (signs, labels, screens, etc.) - without it, the model generates gibberish. + +Output only the revised prompt and nothing else.""" + +# docstyle-ignore +SYSTEM_MESSAGE_UPSAMPLING_I2I = """You are FLUX.2 by Black Forest Labs, an image-editing expert. You convert editing requests into one concise instruction (50-80 words, ~30 for brief requests). + +Rules: +- Single instruction only, no commentary +- Use clear, analytical language (avoid "whimsical," "cascading," etc.) +- Specify what changes AND what stays the same (face, lighting, composition) +- Reference actual image elements +- Turn negatives into positives ("don't change X" → "keep X") +- Make abstractions concrete ("futuristic" → "glowing cyan neon, metallic panels") +- Keep content PG-13 + +Output only the final instruction in plain text and nothing else.""" diff --git a/diffusers/pipelines/hidream_image/__init__.py b/diffusers/pipelines/hidream_image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..498df900e68b70920e5fc49764eb6bdcd2dc3354 --- /dev/null +++ b/diffusers/pipelines/hidream_image/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["HiDreamImagePipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_hidream_image"] = ["HiDreamImagePipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_hidream_image import HiDreamImagePipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/hidream_image/pipeline_hidream_image.py b/diffusers/pipelines/hidream_image/pipeline_hidream_image.py new file mode 100644 index 0000000000000000000000000000000000000000..8e5e078cc2af6ce67489f21d483cf9b283c21440 --- /dev/null +++ b/diffusers/pipelines/hidream_image/pipeline_hidream_image.py @@ -0,0 +1,1053 @@ +# Copyright 2025 HiDream-ai Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable + +import torch +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, + LlamaForCausalLM, + PreTrainedTokenizerFast, + T5EncoderModel, + T5Tokenizer, +) + +from ...image_processor import VaeImageProcessor +from ...loaders import HiDreamImageLoraLoaderMixin +from ...models import AutoencoderKL, HiDreamImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HiDreamImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from transformers import AutoTokenizer, LlamaForCausalLM + >>> from diffusers import HiDreamImagePipeline + + >>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") + >>> text_encoder_4 = LlamaForCausalLM.from_pretrained( + ... "meta-llama/Meta-Llama-3.1-8B-Instruct", + ... output_hidden_states=True, + ... output_attentions=True, + ... torch_dtype=torch.bfloat16, + ... ) + + >>> pipe = HiDreamImagePipeline.from_pretrained( + ... "HiDream-ai/HiDream-I1-Full", + ... tokenizer_4=tokenizer_4, + ... text_encoder_4=text_encoder_4, + ... torch_dtype=torch.bfloat16, + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> image = pipe( + ... 'A cat holding a sign that says "Hi-Dreams.ai".', + ... height=1024, + ... width=1024, + ... guidance_scale=5.0, + ... num_inference_steps=50, + ... generator=torch.Generator("cuda").manual_seed(0), + ... ).images[0] + >>> image.save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class HiDreamImagePipeline(DiffusionPipeline, HiDreamImageLoraLoaderMixin): + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds_t5", "prompt_embeds_llama3", "pooled_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5Tokenizer, + text_encoder_4: LlamaForCausalLM, + tokenizer_4: PreTrainedTokenizerFast, + transformer: HiDreamImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + text_encoder_4=text_encoder_4, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + tokenizer_4=tokenizer_4, + scheduler=scheduler, + transformer=transformer, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + # HiDreamImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.default_sample_size = 128 + if getattr(self, "tokenizer_4", None) is not None: + self.tokenizer_4.pad_token = self.tokenizer_4.eos_token + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + max_sequence_length: int = 128, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder_3.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=min(max_sequence_length, self.tokenizer_3.model_max_length), + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode( + untruncated_ids[:, min(max_sequence_length, self.tokenizer_3.model_max_length) - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {min(max_sequence_length, self.tokenizer_3.model_max_length)} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + tokenizer, + text_encoder, + prompt: str | list[str], + max_sequence_length: int = 128, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=min(max_sequence_length, 218), + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, 218 - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {218} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + return prompt_embeds + + def _get_llama3_prompt_embeds( + self, + prompt: str | list[str] = None, + max_sequence_length: int = 128, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder_4.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer_4( + prompt, + padding="max_length", + max_length=min(max_sequence_length, self.tokenizer_4.model_max_length), + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer_4(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_4.batch_decode( + untruncated_ids[:, min(max_sequence_length, self.tokenizer_4.model_max_length) - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {min(max_sequence_length, self.tokenizer_4.model_max_length)} tokens: {removed_text}" + ) + + outputs = self.text_encoder_4( + text_input_ids.to(device), + attention_mask=attention_mask.to(device), + output_hidden_states=True, + output_attentions=True, + ) + + prompt_embeds = outputs.hidden_states[1:] + prompt_embeds = torch.stack(prompt_embeds, dim=0) + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str] | None = None, + prompt_2: str | list[str] | None = None, + prompt_3: str | list[str] | None = None, + prompt_4: str | list[str] | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + negative_prompt_3: str | list[str] | None = None, + negative_prompt_4: str | list[str] | None = None, + prompt_embeds_t5: list[torch.FloatTensor] | None = None, + prompt_embeds_llama3: list[torch.FloatTensor] | None = None, + negative_prompt_embeds_t5: list[torch.FloatTensor] | None = None, + negative_prompt_embeds_llama3: list[torch.FloatTensor] | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 128, + lora_scale: float | None = None, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = pooled_prompt_embeds.shape[0] + + device = device or self._execution_device + + if pooled_prompt_embeds is None: + pooled_prompt_embeds_1 = self._get_clip_prompt_embeds( + self.tokenizer, self.text_encoder, prompt, max_sequence_length, device, dtype + ) + + if do_classifier_free_guidance and negative_pooled_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if len(negative_prompt) > 1 and len(negative_prompt) != batch_size: + raise ValueError(f"negative_prompt must be of length 1 or {batch_size}") + + negative_pooled_prompt_embeds_1 = self._get_clip_prompt_embeds( + self.tokenizer, self.text_encoder, negative_prompt, max_sequence_length, device, dtype + ) + + if negative_pooled_prompt_embeds_1.shape[0] == 1 and batch_size > 1: + negative_pooled_prompt_embeds_1 = negative_pooled_prompt_embeds_1.repeat(batch_size, 1) + + if pooled_prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + if len(prompt_2) > 1 and len(prompt_2) != batch_size: + raise ValueError(f"prompt_2 must be of length 1 or {batch_size}") + + pooled_prompt_embeds_2 = self._get_clip_prompt_embeds( + self.tokenizer_2, self.text_encoder_2, prompt_2, max_sequence_length, device, dtype + ) + + if pooled_prompt_embeds_2.shape[0] == 1 and batch_size > 1: + pooled_prompt_embeds_2 = pooled_prompt_embeds_2.repeat(batch_size, 1) + + if do_classifier_free_guidance and negative_pooled_prompt_embeds is None: + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_2 = [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + + if len(negative_prompt_2) > 1 and len(negative_prompt_2) != batch_size: + raise ValueError(f"negative_prompt_2 must be of length 1 or {batch_size}") + + negative_pooled_prompt_embeds_2 = self._get_clip_prompt_embeds( + self.tokenizer_2, self.text_encoder_2, negative_prompt_2, max_sequence_length, device, dtype + ) + + if negative_pooled_prompt_embeds_2.shape[0] == 1 and batch_size > 1: + negative_pooled_prompt_embeds_2 = negative_pooled_prompt_embeds_2.repeat(batch_size, 1) + + if pooled_prompt_embeds is None: + pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1) + + if do_classifier_free_guidance and negative_pooled_prompt_embeds is None: + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embeds_1, negative_pooled_prompt_embeds_2], dim=-1 + ) + + if prompt_embeds_t5 is None: + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + if len(prompt_3) > 1 and len(prompt_3) != batch_size: + raise ValueError(f"prompt_3 must be of length 1 or {batch_size}") + + prompt_embeds_t5 = self._get_t5_prompt_embeds(prompt_3, max_sequence_length, device, dtype) + + if prompt_embeds_t5.shape[0] == 1 and batch_size > 1: + prompt_embeds_t5 = prompt_embeds_t5.repeat(batch_size, 1, 1) + + if do_classifier_free_guidance and negative_prompt_embeds_t5 is None: + negative_prompt_3 = negative_prompt_3 or negative_prompt + negative_prompt_3 = [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + + if len(negative_prompt_3) > 1 and len(negative_prompt_3) != batch_size: + raise ValueError(f"negative_prompt_3 must be of length 1 or {batch_size}") + + negative_prompt_embeds_t5 = self._get_t5_prompt_embeds( + negative_prompt_3, max_sequence_length, device, dtype + ) + + if negative_prompt_embeds_t5.shape[0] == 1 and batch_size > 1: + negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(batch_size, 1, 1) + + if prompt_embeds_llama3 is None: + prompt_4 = prompt_4 or prompt + prompt_4 = [prompt_4] if isinstance(prompt_4, str) else prompt_4 + + if len(prompt_4) > 1 and len(prompt_4) != batch_size: + raise ValueError(f"prompt_4 must be of length 1 or {batch_size}") + + prompt_embeds_llama3 = self._get_llama3_prompt_embeds(prompt_4, max_sequence_length, device, dtype) + + if prompt_embeds_llama3.shape[0] == 1 and batch_size > 1: + prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, batch_size, 1, 1) + + if do_classifier_free_guidance and negative_prompt_embeds_llama3 is None: + negative_prompt_4 = negative_prompt_4 or negative_prompt + negative_prompt_4 = [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4 + + if len(negative_prompt_4) > 1 and len(negative_prompt_4) != batch_size: + raise ValueError(f"negative_prompt_4 must be of length 1 or {batch_size}") + + negative_prompt_embeds_llama3 = self._get_llama3_prompt_embeds( + negative_prompt_4, max_sequence_length, device, dtype + ) + + if negative_prompt_embeds_llama3.shape[0] == 1 and batch_size > 1: + negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, batch_size, 1, 1) + + # duplicate pooled_prompt_embeds for each generation per prompt + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + # duplicate t5_prompt_embeds for batch_size and num_images_per_prompt + bs_embed, seq_len, _ = prompt_embeds_t5.shape + if bs_embed == 1 and batch_size > 1: + prompt_embeds_t5 = prompt_embeds_t5.repeat(batch_size, 1, 1) + elif bs_embed > 1 and bs_embed != batch_size: + raise ValueError(f"cannot duplicate prompt_embeds_t5 of batch size {bs_embed}") + prompt_embeds_t5 = prompt_embeds_t5.repeat(1, num_images_per_prompt, 1) + prompt_embeds_t5 = prompt_embeds_t5.view(batch_size * num_images_per_prompt, seq_len, -1) + + # duplicate llama3_prompt_embeds for batch_size and num_images_per_prompt + _, bs_embed, seq_len, dim = prompt_embeds_llama3.shape + if bs_embed == 1 and batch_size > 1: + prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, batch_size, 1, 1) + elif bs_embed > 1 and bs_embed != batch_size: + raise ValueError(f"cannot duplicate prompt_embeds_llama3 of batch size {bs_embed}") + prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, 1, num_images_per_prompt, 1) + prompt_embeds_llama3 = prompt_embeds_llama3.view(-1, batch_size * num_images_per_prompt, seq_len, dim) + + if do_classifier_free_guidance: + # duplicate negative_pooled_prompt_embeds for batch_size and num_images_per_prompt + bs_embed, seq_len = negative_pooled_prompt_embeds.shape + if bs_embed == 1 and batch_size > 1: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(batch_size, 1) + elif bs_embed > 1 and bs_embed != batch_size: + raise ValueError(f"cannot duplicate negative_pooled_prompt_embeds of batch size {bs_embed}") + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + # duplicate negative_t5_prompt_embeds for batch_size and num_images_per_prompt + bs_embed, seq_len, _ = negative_prompt_embeds_t5.shape + if bs_embed == 1 and batch_size > 1: + negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(batch_size, 1, 1) + elif bs_embed > 1 and bs_embed != batch_size: + raise ValueError(f"cannot duplicate negative_prompt_embeds_t5 of batch size {bs_embed}") + negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds_t5 = negative_prompt_embeds_t5.view(batch_size * num_images_per_prompt, seq_len, -1) + + # duplicate negative_prompt_embeds_llama3 for batch_size and num_images_per_prompt + _, bs_embed, seq_len, dim = negative_prompt_embeds_llama3.shape + if bs_embed == 1 and batch_size > 1: + negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, batch_size, 1, 1) + elif bs_embed > 1 and bs_embed != batch_size: + raise ValueError(f"cannot duplicate negative_prompt_embeds_llama3 of batch size {bs_embed}") + negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, 1, num_images_per_prompt, 1) + negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.view( + -1, batch_size * num_images_per_prompt, seq_len, dim + ) + + return ( + prompt_embeds_t5, + negative_prompt_embeds_t5, + prompt_embeds_llama3, + negative_prompt_embeds_llama3, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + def check_inputs( + self, + prompt, + prompt_2, + prompt_3, + prompt_4, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + negative_prompt_4=None, + prompt_embeds_t5=None, + prompt_embeds_llama3=None, + negative_prompt_embeds_t5=None, + negative_prompt_embeds_llama3=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and pooled_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `pooled_prompt_embeds`: {pooled_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and pooled_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `pooled_prompt_embeds`: {pooled_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_3 is not None and prompt_embeds_t5 is not None: + raise ValueError( + f"Cannot forward both `prompt_3`: {prompt_3} and `prompt_embeds_t5`: {prompt_embeds_t5}. Please make sure to" + " only forward one of the two." + ) + elif prompt_4 is not None and prompt_embeds_llama3 is not None: + raise ValueError( + f"Cannot forward both `prompt_4`: {prompt_4} and `prompt_embeds_llama3`: {prompt_embeds_llama3}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and pooled_prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `pooled_prompt_embeds`. Cannot leave both `prompt` and `pooled_prompt_embeds` undefined." + ) + elif prompt is None and prompt_embeds_t5 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_t5`. Cannot leave both `prompt` and `prompt_embeds_t5` undefined." + ) + elif prompt is None and prompt_embeds_llama3 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_llama3`. Cannot leave both `prompt` and `prompt_embeds_llama3` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): + raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") + elif prompt_4 is not None and (not isinstance(prompt_4, str) and not isinstance(prompt_4, list)): + raise ValueError(f"`prompt_4` has to be of type `str` or `list` but is {type(prompt_4)}") + + if negative_prompt is not None and negative_pooled_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_pooled_prompt_embeds`:" + f" {negative_pooled_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_pooled_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_pooled_prompt_embeds`:" + f" {negative_pooled_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_3 is not None and negative_prompt_embeds_t5 is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds_t5`:" + f" {negative_prompt_embeds_t5}. Please make sure to only forward one of the two." + ) + elif negative_prompt_4 is not None and negative_prompt_embeds_llama3 is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_4`: {negative_prompt_4} and `negative_prompt_embeds_llama3`:" + f" {negative_prompt_embeds_llama3}. Please make sure to only forward one of the two." + ) + + if pooled_prompt_embeds is not None and negative_pooled_prompt_embeds is not None: + if pooled_prompt_embeds.shape != negative_pooled_prompt_embeds.shape: + raise ValueError( + "`pooled_prompt_embeds` and `negative_pooled_prompt_embeds` must have the same shape when passed directly, but" + f" got: `pooled_prompt_embeds` {pooled_prompt_embeds.shape} != `negative_pooled_prompt_embeds`" + f" {negative_pooled_prompt_embeds.shape}." + ) + if prompt_embeds_t5 is not None and negative_prompt_embeds_t5 is not None: + if prompt_embeds_t5.shape != negative_prompt_embeds_t5.shape: + raise ValueError( + "`prompt_embeds_t5` and `negative_prompt_embeds_t5` must have the same shape when passed directly, but" + f" got: `prompt_embeds_t5` {prompt_embeds_t5.shape} != `negative_prompt_embeds_t5`" + f" {negative_prompt_embeds_t5.shape}." + ) + if prompt_embeds_llama3 is not None and negative_prompt_embeds_llama3 is not None: + if prompt_embeds_llama3.shape != negative_prompt_embeds_llama3.shape: + raise ValueError( + "`prompt_embeds_llama3` and `negative_prompt_embeds_llama3` must have the same shape when passed directly, but" + f" got: `prompt_embeds_llama3` {prompt_embeds_llama3.shape} != `negative_prompt_embeds_llama3`" + f" {negative_prompt_embeds_llama3.shape}." + ) + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + prompt_3: str | list[str] | None = None, + prompt_4: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + negative_prompt_3: str | list[str] | None = None, + negative_prompt_4: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds_t5: torch.FloatTensor | None = None, + prompt_embeds_llama3: torch.FloatTensor | None = None, + negative_prompt_embeds_t5: torch.FloatTensor | None = None, + negative_prompt_embeds_llama3: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 128, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + will be used instead. + prompt_4 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_4` and `text_encoder_4`. If not defined, `prompt` is + will be used instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 3.5): + Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_4 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_4` and + `text_encoder_4`. If not defined, `negative_prompt` is used in all the text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 128): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.hidream_image.HiDreamImagePipelineOutput`] or `tuple`: + [`~pipelines.hidream_image.HiDreamImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated. images. + """ + + prompt_embeds = kwargs.get("prompt_embeds", None) + negative_prompt_embeds = kwargs.get("negative_prompt_embeds", None) + + if prompt_embeds is not None: + deprecation_message = "The `prompt_embeds` argument is deprecated. Please use `prompt_embeds_t5` and `prompt_embeds_llama3` instead." + deprecate("prompt_embeds", "0.35.0", deprecation_message) + prompt_embeds_t5 = prompt_embeds[0] + prompt_embeds_llama3 = prompt_embeds[1] + + if negative_prompt_embeds is not None: + deprecation_message = "The `negative_prompt_embeds` argument is deprecated. Please use `negative_prompt_embeds_t5` and `negative_prompt_embeds_llama3` instead." + deprecate("negative_prompt_embeds", "0.35.0", deprecation_message) + negative_prompt_embeds_t5 = negative_prompt_embeds[0] + negative_prompt_embeds_llama3 = negative_prompt_embeds[1] + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + division = self.vae_scale_factor * 2 + S_max = (self.default_sample_size * self.vae_scale_factor) ** 2 + scale = S_max / (width * height) + scale = math.sqrt(scale) + width, height = int(width * scale // division * division), int(height * scale // division * division) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + prompt_4, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + negative_prompt_4=negative_prompt_4, + prompt_embeds_t5=prompt_embeds_t5, + prompt_embeds_llama3=prompt_embeds_llama3, + negative_prompt_embeds_t5=negative_prompt_embeds_t5, + negative_prompt_embeds_llama3=negative_prompt_embeds_llama3, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + elif pooled_prompt_embeds is not None: + batch_size = pooled_prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode prompt + lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None + ( + prompt_embeds_t5, + negative_prompt_embeds_t5, + prompt_embeds_llama3, + negative_prompt_embeds_llama3, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + prompt_4=prompt_4, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + negative_prompt_4=negative_prompt_4, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds_t5=prompt_embeds_t5, + prompt_embeds_llama3=prompt_embeds_llama3, + negative_prompt_embeds_t5=negative_prompt_embeds_t5, + negative_prompt_embeds_llama3=negative_prompt_embeds_llama3, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + if self.do_classifier_free_guidance: + prompt_embeds_t5 = torch.cat([negative_prompt_embeds_t5, prompt_embeds_t5], dim=0) + prompt_embeds_llama3 = torch.cat([negative_prompt_embeds_llama3, prompt_embeds_llama3], dim=1) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + pooled_prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + mu = calculate_shift(self.transformer.max_seq) + scheduler_kwargs = {"mu": mu} + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + if isinstance(self.scheduler, UniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=timestep_device) # , shift=math.exp(mu)) + timesteps = self.scheduler.timesteps + else: + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + timestep_device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timesteps=timestep, + encoder_hidden_states_t5=prompt_embeds_t5, + encoder_hidden_states_llama3=prompt_embeds_llama3, + pooled_embeds=pooled_prompt_embeds, + return_dict=False, + )[0] + noise_pred = -noise_pred + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds_t5 = callback_outputs.pop("prompt_embeds_t5", prompt_embeds_t5) + prompt_embeds_llama3 = callback_outputs.pop("prompt_embeds_llama3", prompt_embeds_llama3) + pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", pooled_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return HiDreamImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/hidream_image/pipeline_output.py b/diffusers/pipelines/hidream_image/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..1802c7220691d5b14c99580de3a6bae5846b838d --- /dev/null +++ b/diffusers/pipelines/hidream_image/pipeline_output.py @@ -0,0 +1,34 @@ +# Copyright 2025 HiDream-ai Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class HiDreamImagePipelineOutput(BaseOutput): + """ + Output class for HiDreamImage pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image] | np.ndarray diff --git a/diffusers/pipelines/hunyuan_image/__init__.py b/diffusers/pipelines/hunyuan_image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7da72fa12b2c18c1b9ec7f9fb9b2170e4b516041 --- /dev/null +++ b/diffusers/pipelines/hunyuan_image/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_hunyuanimage"] = ["HunyuanImagePipeline"] + _import_structure["pipeline_hunyuanimage_refiner"] = ["HunyuanImageRefinerPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_hunyuanimage import HunyuanImagePipeline + from .pipeline_hunyuanimage_refiner import HunyuanImageRefinerPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage.py b/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage.py new file mode 100644 index 0000000000000000000000000000000000000000..50239e9afa22678bfe1532291fa75cda6e756988 --- /dev/null +++ b/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage.py @@ -0,0 +1,868 @@ +# Copyright 2025 Hunyuan-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect +import re +from typing import Any, Callable + +import numpy as np +import torch +from transformers import ByT5Tokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, T5EncoderModel + +from ...guiders import AdaptiveProjectedMixGuidance +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKLHunyuanImage, HunyuanImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HunyuanImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import HunyuanImagePipeline + + >>> pipe = HunyuanImagePipeline.from_pretrained( + ... "hunyuanvideo-community/HunyuanImage-2.1-Diffusers", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, negative_prompt="", num_inference_steps=50).images[0] + >>> image.save("hunyuanimage.png") + ``` +""" + + +def extract_glyph_text(prompt: str): + """ + Extract text enclosed in quotes for glyph rendering. + + Finds text in single quotes, double quotes, and Chinese quotes, then formats it for byT5 processing. + + Args: + prompt: Input text prompt + + Returns: + Formatted glyph text string or None if no quoted text found + """ + text_prompt_texts = [] + pattern_quote_single = r"\'(.*?)\'" + pattern_quote_double = r"\"(.*?)\"" + pattern_quote_chinese_single = r"‘(.*?)’" + pattern_quote_chinese_double = r"“(.*?)”" + + matches_quote_single = re.findall(pattern_quote_single, prompt) + matches_quote_double = re.findall(pattern_quote_double, prompt) + matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, prompt) + matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, prompt) + + text_prompt_texts.extend(matches_quote_single) + text_prompt_texts.extend(matches_quote_double) + text_prompt_texts.extend(matches_quote_chinese_single) + text_prompt_texts.extend(matches_quote_chinese_double) + + if text_prompt_texts: + glyph_text_formatted = ". ".join([f'Text "{text}"' for text in text_prompt_texts]) + ". " + else: + glyph_text_formatted = None + + return glyph_text_formatted + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class HunyuanImagePipeline(DiffusionPipeline): + r""" + The HunyuanImage pipeline for text-to-image generation. + + Args: + transformer ([`HunyuanImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLHunyuanImage`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer]. + text_encoder_2 ([`T5EncoderModel`]): + [T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel) + variant. + tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer] + guider ([`AdaptiveProjectedMixGuidance`]): + [AdaptiveProjectedMixGuidance]to be used to guide the image generation. + ocr_guider ([`AdaptiveProjectedMixGuidance`], *optional*): + [AdaptiveProjectedMixGuidance] to be used to guide the image generation when text rendering is needed. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + _optional_components = ["ocr_guider", "guider"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLHunyuanImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: ByT5Tokenizer, + transformer: HunyuanImageTransformer2DModel, + guider: AdaptiveProjectedMixGuidance | None = None, + ocr_guider: AdaptiveProjectedMixGuidance | None = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + guider=guider, + ocr_guider=ocr_guider, + ) + + self.vae_scale_factor = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 32 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = 1000 + self.tokenizer_2_max_length = 128 + self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>" + self.prompt_template_encode_start_idx = 34 + self.default_sample_size = 64 + + def _get_qwen_prompt_embeds( + self, + tokenizer: Qwen2Tokenizer, + text_encoder: Qwen2_5_VLForConditionalGeneration, + prompt: str | list[str] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + tokenizer_max_length: int = 1000, + template: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>", + drop_idx: int = 34, + hidden_state_skip_layer: int = 2, + ): + device = device or self._execution_device + dtype = dtype or text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + txt = [template.format(e) for e in prompt] + txt_tokens = tokenizer( + txt, max_length=tokenizer_max_length + drop_idx, padding="max_length", truncation=True, return_tensors="pt" + ).to(device) + + encoder_hidden_states = text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + prompt_embeds = encoder_hidden_states.hidden_states[-(hidden_state_skip_layer + 1)] + + prompt_embeds = prompt_embeds[:, drop_idx:] + encoder_attention_mask = txt_tokens.attention_mask[:, drop_idx:] + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + encoder_attention_mask = encoder_attention_mask.to(device=device) + + return prompt_embeds, encoder_attention_mask + + def _get_byt5_prompt_embeds( + self, + tokenizer: ByT5Tokenizer, + text_encoder: T5EncoderModel, + prompt: str, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + tokenizer_max_length: int = 128, + ): + device = device or self._execution_device + dtype = dtype or text_encoder.dtype + + if isinstance(prompt, list): + raise ValueError("byt5 prompt should be a string") + elif prompt is None: + raise ValueError("byt5 prompt should not be None") + + txt_tokens = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer_max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ).to(device) + + prompt_embeds = text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask.float(), + )[0] + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + encoder_attention_mask = txt_tokens.attention_mask.to(device=device) + + return prompt_embeds, encoder_attention_mask + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + batch_size: int = 1, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + prompt_embeds_2: torch.Tensor | None = None, + prompt_embeds_mask_2: torch.Tensor | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + batch_size (`int`): + batch size of prompts, defaults to 1 + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input + argument. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input + argument using self.tokenizer_2 and self.text_encoder_2. + prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input + argument using self.tokenizer_2 and self.text_encoder_2. + """ + device = device or self._execution_device + + if prompt is None: + prompt = [""] * batch_size + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + prompt=prompt, + device=device, + tokenizer_max_length=self.tokenizer_max_length, + template=self.prompt_template_encode, + drop_idx=self.prompt_template_encode_start_idx, + ) + + if prompt_embeds_2 is None: + prompt_embeds_2_list = [] + prompt_embeds_mask_2_list = [] + + glyph_texts = [extract_glyph_text(p) for p in prompt] + for glyph_text in glyph_texts: + if glyph_text is None: + glyph_text_embeds = torch.zeros( + (1, self.tokenizer_2_max_length, self.text_encoder_2.config.d_model), device=device + ) + glyph_text_embeds_mask = torch.zeros( + (1, self.tokenizer_2_max_length), device=device, dtype=torch.int64 + ) + else: + glyph_text_embeds, glyph_text_embeds_mask = self._get_byt5_prompt_embeds( + tokenizer=self.tokenizer_2, + text_encoder=self.text_encoder_2, + prompt=glyph_text, + device=device, + tokenizer_max_length=self.tokenizer_2_max_length, + ) + + prompt_embeds_2_list.append(glyph_text_embeds) + prompt_embeds_mask_2_list.append(glyph_text_embeds_mask) + + prompt_embeds_2 = torch.cat(prompt_embeds_2_list, dim=0) + prompt_embeds_mask_2 = torch.cat(prompt_embeds_mask_2_list, dim=0) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + _, seq_len_2, _ = prompt_embeds_2.shape + prompt_embeds_2 = prompt_embeds_2.repeat(1, num_images_per_prompt, 1) + prompt_embeds_2 = prompt_embeds_2.view(batch_size * num_images_per_prompt, seq_len_2, -1) + prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask_2 = prompt_embeds_mask_2.view(batch_size * num_images_per_prompt, seq_len_2) + + return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + prompt_embeds_2=None, + prompt_embeds_mask_2=None, + negative_prompt_embeds_2=None, + negative_prompt_embeds_mask_2=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + + if prompt_embeds_2 is not None and prompt_embeds_mask_2 is None: + raise ValueError( + "If `prompt_embeds_2` are provided, `prompt_embeds_mask_2` also have to be passed. Make sure to generate `prompt_embeds_mask_2` from the same text encoder that was used to generate `prompt_embeds_2`." + ) + if negative_prompt_embeds_2 is not None and negative_prompt_embeds_mask_2 is None: + raise ValueError( + "If `negative_prompt_embeds_2` are provided, `negative_prompt_embeds_mask_2` also have to be passed. Make sure to generate `negative_prompt_embeds_mask_2` from the same text encoder that was used to generate `negative_prompt_embeds_2`." + ) + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + return latents.to(device=device, dtype=dtype) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + return latents + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + distilled_guidance_scale: float | None = 3.25, + sigmas: list[float] | None = None, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + prompt_embeds_2: torch.Tensor | None = None, + prompt_embeds_mask_2: torch.Tensor | None = None, + negative_prompt_embeds_2: torch.Tensor | None = None, + negative_prompt_embeds_mask_2: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined and negative_prompt_embeds is + not provided, will use an empty negative prompt. Ignored when not using guidance. ). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + distilled_guidance_scale (`float`, *optional*, defaults to None): + A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance + where the guidance scale is applied during inference through noise prediction rescaling, guidance + distilled models take the guidance scale directly as an input parameter during forward pass. Guidance + is enabled by setting `distilled_guidance_scale > 1`. Higher guidance scale encourages to generate + images that are closely linked to the text `prompt`, usually at the expense of lower image quality. For + guidance distilled models, this parameter is required. For non-distilled models, this parameter will be + ignored. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated text embeddings mask. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, text embeddings mask will be generated from `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated text embeddings for ocr. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, text embeddings for ocr will be generated from `prompt` input argument. + prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated text embeddings mask for ocr. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, text embeddings mask for ocr will be generated from `prompt` input + argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings mask. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative text embeddings mask will be generated from `negative_prompt` + input argument. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings for ocr. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative text embeddings for ocr will be generated from `negative_prompt` + input argument. + negative_prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings mask for ocr. Can be used to easily tweak text inputs, *e.g.* + prompt weighting. If not provided, negative text embeddings mask for ocr will be generated from + `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] or `tuple`: + [`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds_2=prompt_embeds_2, + prompt_embeds_mask_2=prompt_embeds_mask_2, + negative_prompt_embeds_2=negative_prompt_embeds_2, + negative_prompt_embeds_mask_2=negative_prompt_embeds_mask_2, + ) + + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. prepare prompt embeds + + prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds_2=prompt_embeds_2, + prompt_embeds_mask_2=prompt_embeds_mask_2, + ) + + prompt_embeds = prompt_embeds.to(self.transformer.dtype) + prompt_embeds_2 = prompt_embeds_2.to(self.transformer.dtype) + + # select guider + if not torch.all(prompt_embeds_2 == 0) and self.ocr_guider is not None: + # prompt contains ocr and pipeline has a guider for ocr + guider = self.ocr_guider + elif self.guider is not None: + guider = self.guider + # distilled model does not use guidance method, use default guider with enabled=False + else: + guider = AdaptiveProjectedMixGuidance(enabled=False) + + if guider._enabled and guider.num_conditions > 1: + ( + negative_prompt_embeds, + negative_prompt_embeds_mask, + negative_prompt_embeds_2, + negative_prompt_embeds_mask_2, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds_2=negative_prompt_embeds_2, + prompt_embeds_mask_2=negative_prompt_embeds_mask_2, + ) + + negative_prompt_embeds = negative_prompt_embeds.to(self.transformer.dtype) + negative_prompt_embeds_2 = negative_prompt_embeds_2.to(self.transformer.dtype) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size=batch_size * num_images_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance (for guidance-distilled model) + if self.transformer.config.guidance_embeds and distilled_guidance_scale is None: + raise ValueError("`distilled_guidance_scale` is required for guidance-distilled model.") + + if self.transformer.config.guidance_embeds: + guidance = ( + torch.tensor( + [distilled_guidance_scale] * latents.shape[0], dtype=self.transformer.dtype, device=device + ) + * 1000.0 + ) + + else: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + # 6. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + if self.transformer.config.use_meanflow: + if i == len(timesteps) - 1: + timestep_r = torch.tensor([0.0], device=device) + else: + timestep_r = timesteps[i + 1] + timestep_r = timestep_r.expand(latents.shape[0]).to(latents.dtype) + else: + timestep_r = None + + # Step 1: Collect model inputs needed for the guidance method + # conditional inputs should always be first element in the tuple + guider_inputs = { + "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds), + "encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask), + "encoder_hidden_states_2": (prompt_embeds_2, negative_prompt_embeds_2), + "encoder_attention_mask_2": (prompt_embeds_mask_2, negative_prompt_embeds_mask_2), + } + + # Step 2: Update guider's internal state for this denoising step + guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t) + + # Step 3: Prepare batched model inputs based on the guidance method + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). + guider_state = guider.prepare_inputs(guider_inputs) + # Step 4: Run the denoiser for each batch + # Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.). + # We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred. + for guider_state_batch in guider_state: + guider.prepare_models(self.transformer) + + # Extract conditioning kwargs for this batch (e.g., encoder_hidden_states) + cond_kwargs = { + input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys() + } + + # e.g. "pred_cond"/"pred_uncond" + context_name = getattr(guider_state_batch, guider._identifier_key) + with self.transformer.cache_context(context_name): + # Run denoiser and store noise prediction in this batch + guider_state_batch.noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep, + timestep_r=timestep_r, + guidance=guidance, + attention_kwargs=self.attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + + # Cleanup model (e.g., remove hooks) + guider.cleanup_models(self.transformer) + + # Step 5: Combine predictions using the guidance method + # The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm. + # Continuing the CFG example, the guider receives: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0 + # {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1 + # ] + # And extracts predictions using the __guidance_identifier__: + # pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond + # pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond + # Then applies CFG formula: + # noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + # Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond) + noise_pred = guider(guider_state)[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return HunyuanImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage_refiner.py b/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..93e4deb2974a9f9854c49cd45bcaec1b7e7a4257 --- /dev/null +++ b/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage_refiner.py @@ -0,0 +1,754 @@ +# Copyright 2025 Hunyuan-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer + +from ...guiders import AdaptiveProjectedMixGuidance +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models import AutoencoderKLHunyuanImageRefiner, HunyuanImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HunyuanImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import HunyuanImageRefinerPipeline + + >>> pipe = HunyuanImageRefinerPipeline.from_pretrained( + ... "hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> image = load_image("path/to/image.png") + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, image=image, num_inference_steps=4).images[0] + >>> image.save("hunyuanimage.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class HunyuanImageRefinerPipeline(DiffusionPipeline): + r""" + The HunyuanImage pipeline for text-to-image generation. + + Args: + transformer ([`HunyuanImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLHunyuanImageRefiner`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer]. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + _optional_components = ["guider"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLHunyuanImageRefiner, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + transformer: HunyuanImageTransformer2DModel, + guider: AdaptiveProjectedMixGuidance | None = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + guider=guider, + ) + + self.vae_scale_factor = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 16 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = 256 + self.prompt_template_encode = "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + self.prompt_template_encode_start_idx = 36 + self.default_sample_size = 64 + self.latent_channels = self.transformer.config.in_channels // 2 if getattr(self, "transformer", None) else 64 + + # Copied from diffusers.pipelines.hunyuan_image.pipeline_hunyuanimage.HunyuanImagePipeline._get_qwen_prompt_embeds + def _get_qwen_prompt_embeds( + self, + tokenizer: Qwen2Tokenizer, + text_encoder: Qwen2_5_VLForConditionalGeneration, + prompt: str | list[str] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + tokenizer_max_length: int = 1000, + template: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>", + drop_idx: int = 34, + hidden_state_skip_layer: int = 2, + ): + device = device or self._execution_device + dtype = dtype or text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + txt = [template.format(e) for e in prompt] + txt_tokens = tokenizer( + txt, max_length=tokenizer_max_length + drop_idx, padding="max_length", truncation=True, return_tensors="pt" + ).to(device) + + encoder_hidden_states = text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + prompt_embeds = encoder_hidden_states.hidden_states[-(hidden_state_skip_layer + 1)] + + prompt_embeds = prompt_embeds[:, drop_idx:] + encoder_attention_mask = txt_tokens.attention_mask[:, drop_idx:] + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + encoder_attention_mask = encoder_attention_mask.to(device=device) + + return prompt_embeds, encoder_attention_mask + + def encode_prompt( + self, + prompt: str | list[str] | None = None, + device: torch.device | None = None, + batch_size: int = 1, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + batch_size (`int`): + batch size of prompts, defaults to 1 + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input + argument. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input + argument using self.tokenizer_2 and self.text_encoder_2. + prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input + argument using self.tokenizer_2 and self.text_encoder_2. + """ + device = device or self._execution_device + + if prompt is None: + prompt = [""] * batch_size + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + prompt=prompt, + device=device, + tokenizer_max_length=self.tokenizer_max_length, + template=self.prompt_template_encode, + drop_idx=self.prompt_template_encode_start_idx, + ) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + def prepare_latents( + self, + image_latents, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + strength=0.25, + ): + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + shape = (batch_size, num_channels_latents, 1, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + cond_latents = strength * noise + (1 - strength) * image_latents + + return latents, cond_latents + + @staticmethod + def _reorder_image_tokens(image_latents): + image_latents = torch.cat((image_latents[:, :, :1], image_latents), dim=2) + batch_size, num_latent_channels, num_latent_frames, latent_height, latent_width = image_latents.shape + image_latents = image_latents.permute(0, 2, 1, 3, 4) + image_latents = image_latents.reshape( + batch_size, num_latent_frames // 2, num_latent_channels * 2, latent_height, latent_width + ) + image_latents = image_latents.permute(0, 2, 1, 3, 4).contiguous() + + return image_latents + + @staticmethod + def _restore_image_tokens_order(latents): + """Restore image tokens order by splitting channels and removing first frame slice.""" + batch_size, num_latent_channels, num_latent_frames, latent_height, latent_width = latents.shape + + latents = latents.permute(0, 2, 1, 3, 4) # B, F, C, H, W + latents = latents.reshape( + batch_size, num_latent_frames * 2, num_latent_channels // 2, latent_height, latent_width + ) # B, F*2, C//2, H, W + + latents = latents.permute(0, 2, 1, 3, 4) # B, C//2, F*2, H, W + # Remove first frame slice + latents = latents[:, :, 1:] + + return latents + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="sample") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="sample") + image_latents = self._reorder_image_tokens(image_latents) + + image_latents = image_latents * self.vae.config.scaling_factor + + return image_latents + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + distilled_guidance_scale: float | None = 3.25, + image: PipelineImageInput | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 4, + sigmas: list[float] | None = None, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, will use an empty negative + prompt. Ignored when not using guidance. + distilled_guidance_scale (`float`, *optional*, defaults to None): + A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance + where the guidance scale is applied during inference through noise prediction rescaling, guidance + distilled models take the guidance scale directly as an input parameter during forward pass. Guidance + is enabled by setting `distilled_guidance_scale > 1`. Higher guidance scale encourages to generate + images that are closely linked to the text `prompt`, usually at the expense of lower image quality. For + guidance distilled models, this parameter is required. For non-distilled models, this parameter will be + ignored. + num_images_per_prompt (`int`, *optional*, defaults to 1): + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] or `tuple`: + [`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. process image + if image is not None and isinstance(image, torch.Tensor) and image.shape[1] == self.latent_channels: + image_latents = image + else: + image = self.image_processor.preprocess(image, height, width) + image = image.unsqueeze(2).to(device, dtype=self.vae.dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + + # 3.prepare prompt embeds + + if self.guider is not None: + guider = self.guider + else: + # distilled model does not use guidance method, use default guider with enabled=False + guider = AdaptiveProjectedMixGuidance(enabled=False) + + requires_unconditional_embeds = guider._enabled and guider.num_conditions > 1 + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + ) + + prompt_embeds = prompt_embeds.to(self.transformer.dtype) + + if requires_unconditional_embeds: + ( + negative_prompt_embeds, + negative_prompt_embeds_mask, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + ) + + negative_prompt_embeds = negative_prompt_embeds.to(self.transformer.dtype) + + # 4. Prepare latent variables + latents, cond_latents = self.prepare_latents( + image_latents=image_latents, + batch_size=batch_size * num_images_per_prompt, + num_channels_latents=self.latent_channels, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance (this pipeline only supports guidance-distilled models) + if distilled_guidance_scale is None: + raise ValueError("`distilled_guidance_scale` is required for guidance-distilled model.") + guidance = ( + torch.tensor([distilled_guidance_scale] * latents.shape[0], dtype=self.transformer.dtype, device=device) + * 1000.0 + ) + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + # 6. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + latent_model_input = torch.cat([latents, cond_latents], dim=1).to(self.transformer.dtype) + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + # Step 1: Collect model inputs needed for the guidance method + # conditional inputs should always be first element in the tuple + guider_inputs = { + "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds), + "encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask), + } + + # Step 2: Update guider's internal state for this denoising step + guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t) + + # Step 3: Prepare batched model inputs based on the guidance method + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). + guider_state = guider.prepare_inputs(guider_inputs) + + # Step 4: Run the denoiser for each batch + # Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.). + # We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred. + for guider_state_batch in guider_state: + guider.prepare_models(self.transformer) + + # Extract conditioning kwargs for this batch (e.g., encoder_hidden_states) + cond_kwargs = { + input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys() + } + + # e.g. "pred_cond"/"pred_uncond" + context_name = getattr(guider_state_batch, guider._identifier_key) + with self.transformer.cache_context(context_name): + # Run denoiser and store noise prediction in this batch + guider_state_batch.noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + guidance=guidance, + attention_kwargs=self.attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + + # Cleanup model (e.g., remove hooks) + guider.cleanup_models(self.transformer) + + # Step 5: Combine predictions using the guidance method + # The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm. + # Continuing the CFG example, the guider receives: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0 + # {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1 + # ] + # And extracts predictions using the __guidance_identifier__: + # pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond + # pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond + # Then applies CFG formula: + # noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + # Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond) + noise_pred = guider(guider_state)[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + latents = self._restore_image_tokens_order(latents) + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image.squeeze(2), output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return HunyuanImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/hunyuan_image/pipeline_output.py b/diffusers/pipelines/hunyuan_image/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..169436b7d86f78d4584396a42a2d3f1d0209d3a3 --- /dev/null +++ b/diffusers/pipelines/hunyuan_image/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class HunyuanImagePipelineOutput(BaseOutput): + """ + Output class for HunyuanImage pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image, np.ndarray] diff --git a/diffusers/pipelines/hunyuan_video/__init__.py b/diffusers/pipelines/hunyuan_video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d42d38fac9793db2c2fb94c4b41d3755be75d898 --- /dev/null +++ b/diffusers/pipelines/hunyuan_video/__init__.py @@ -0,0 +1,54 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_hunyuan_skyreels_image2video"] = ["HunyuanSkyreelsImageToVideoPipeline"] + _import_structure["pipeline_hunyuan_video"] = ["HunyuanVideoPipeline"] + _import_structure["pipeline_hunyuan_video_framepack"] = ["HunyuanVideoFramepackPipeline"] + _import_structure["pipeline_hunyuan_video_image2video"] = ["HunyuanVideoImageToVideoPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_hunyuan_skyreels_image2video import HunyuanSkyreelsImageToVideoPipeline + from .pipeline_hunyuan_video import HunyuanVideoPipeline + from .pipeline_hunyuan_video_framepack import HunyuanVideoFramepackPipeline + from .pipeline_hunyuan_video_image2video import HunyuanVideoImageToVideoPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py b/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py new file mode 100644 index 0000000000000000000000000000000000000000..1a7cae256d638ae2a02d23df33e9321319ab36c2 --- /dev/null +++ b/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py @@ -0,0 +1,832 @@ +# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import HunyuanVideoLoraLoaderMixin +from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HunyuanVideoPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import HunyuanSkyreelsImageToVideoPipeline, HunyuanVideoTransformer3DModel + >>> from diffusers.utils import load_image, export_to_video + + >>> model_id = "hunyuanvideo-community/HunyuanVideo" + >>> transformer_model_id = "Skywork/SkyReels-V1-Hunyuan-I2V" + >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained( + ... transformer_model_id, torch_dtype=torch.bfloat16 + ... ) + >>> pipe = HunyuanSkyreelsImageToVideoPipeline.from_pretrained( + ... model_id, transformer=transformer, torch_dtype=torch.float16 + ... ) + >>> pipe.vae.enable_tiling() + >>> pipe.to("cuda") + + >>> prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." + >>> negative_prompt = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion" + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" + ... ) + + >>> output = pipe( + ... image=image, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_inference_steps=30, + ... true_cfg_scale=6.0, + ... guidance_scale=1.0, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=15) + ``` +""" + + +DEFAULT_PROMPT_TEMPLATE = { + "template": ( + "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + ), + "crop_start": 95, +} + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class HunyuanSkyreelsImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin): + r""" + Pipeline for image-to-video generation using HunyuanVideo. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`LlamaModel`]): + [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + tokenizer (`LlamaTokenizer`): + Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + transformer ([`HunyuanVideoTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLHunyuanVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder_2 ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer_2 (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + text_encoder: LlamaModel, + tokenizer: LlamaTokenizerFast, + transformer: HunyuanVideoTransformer3DModel, + vae: AutoencoderKLHunyuanVideo, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + ) + + self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8 + self.vae_scaling_factor = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.476986 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline._get_llama_prompt_embeds + def _get_llama_prompt_embeds( + self, + prompt: str | list[str], + prompt_template: dict[str, Any], + num_videos_per_prompt: int = 1, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 256, + num_hidden_layers_to_skip: int = 2, + ) -> tuple[torch.Tensor, torch.Tensor]: + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + prompt = [prompt_template["template"].format(p) for p in prompt] + + crop_start = prompt_template.get("crop_start", None) + if crop_start is None: + prompt_template_input = self.tokenizer( + prompt_template["template"], + padding="max_length", + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=False, + ) + crop_start = prompt_template_input["input_ids"].shape[-1] + # Remove <|eot_id|> token and placeholder {} + crop_start -= 2 + + max_sequence_length += crop_start + text_inputs = self.tokenizer( + prompt, + max_length=max_sequence_length, + padding="max_length", + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + prompt_embeds = prompt_embeds.to(dtype=dtype) + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_videos_per_prompt: int = 1, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 77, + ) -> torch.Tensor: + device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str] = None, + prompt_template: dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 256, + ): + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( + prompt, + prompt_template, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=max_sequence_length, + ) + + if pooled_prompt_embeds is None: + if prompt_2 is None: + prompt_2 = prompt + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=77, + ) + + return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + prompt_template=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_template is not None: + if not isinstance(prompt_template, dict): + raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}") + if "template" not in prompt_template: + raise ValueError( + f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" + ) + + def prepare_latents( + self, + image: torch.Tensor, + batch_size: int, + num_channels_latents: int = 32, + height: int = 544, + width: int = 960, + num_frames: int = 97, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + image = image.unsqueeze(2) # [B, C, 1, H, W] + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + else: + image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image] + + image_latents = torch.cat(image_latents, dim=0).to(dtype) * self.vae_scaling_factor + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height, latent_width = height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + padding_shape = (batch_size, num_channels_latents, num_latent_frames - 1, latent_height, latent_width) + + latents_padding = torch.zeros(padding_shape, dtype=dtype, device=device) + image_latents = torch.cat([image_latents, latents_padding], dim=2) + + if latents is None: + latents = randn_tensor(shape, generator=generator, dtype=dtype, device=device) + else: + latents = latents.to(dtype=dtype, device=device) + + return latents, image_latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: str | list[str] = None, + prompt_2: str | list[str] = None, + negative_prompt: str | list[str] = None, + negative_prompt_2: str | list[str] = None, + height: int = 544, + width: int = 960, + num_frames: int = 97, + num_inference_steps: int = 50, + sigmas: list[float] = None, + true_cfg_scale: float = 6.0, + guidance_scale: float = 1.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + prompt_template: dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + max_sequence_length: int = 256, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + height (`int`, defaults to `720`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `129`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + guidance_scale (`float`, defaults to `6.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. Note that the only available + HunyuanVideo model is CFG-distilled, which means that traditional guidance between unconditional and + conditional latent is not applied. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~HunyuanVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds, + callback_on_step_end_tensor_inputs, + prompt_template, + ) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + transformer_dtype = self.transformer.dtype + prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + prompt_embeds = prompt_embeds.to(transformer_dtype) + prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) + pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) + + if do_true_cfg: + negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + prompt_attention_mask=negative_prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas + ) + + # 5. Prepare latent variables + vae_dtype = self.vae.dtype + image = self.video_processor.preprocess(image, height=height, width=width).to(device, vae_dtype) + num_channels_latents = self.transformer.config.in_channels // 2 + latents, image_latents = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + latent_image_input = image_latents.to(transformer_dtype) + + # 6. Prepare guidance condition + guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents.to(transformer_dtype) + latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=1) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + pooled_projections=pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_attention_mask=negative_prompt_attention_mask, + pooled_projections=negative_pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HunyuanVideoPipelineOutput(frames=video) diff --git a/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py new file mode 100644 index 0000000000000000000000000000000000000000..3c6ec39398efd65f060b7e20fae2485d92eeae23 --- /dev/null +++ b/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -0,0 +1,783 @@ +# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import HunyuanVideoLoraLoaderMixin +from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HunyuanVideoPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel + >>> from diffusers.utils import export_to_video + + >>> model_id = "hunyuanvideo-community/HunyuanVideo" + >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained( + ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16) + >>> pipe.vae.enable_tiling() + >>> pipe.to("cuda") + + >>> output = pipe( + ... prompt="A cat walks on the grass, realistic", + ... height=320, + ... width=512, + ... num_frames=61, + ... num_inference_steps=30, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=15) + ``` +""" + + +DEFAULT_PROMPT_TEMPLATE = { + "template": ( + "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + ), + "crop_start": 95, +} + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using HunyuanVideo. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`LlamaModel`]): + [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + tokenizer (`LlamaTokenizer`): + Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + transformer ([`HunyuanVideoTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLHunyuanVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder_2 ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer_2 (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + text_encoder: LlamaModel, + tokenizer: LlamaTokenizerFast, + transformer: HunyuanVideoTransformer3DModel, + vae: AutoencoderKLHunyuanVideo, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + ) + + self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_llama_prompt_embeds( + self, + prompt: str | list[str], + prompt_template: dict[str, Any], + num_videos_per_prompt: int = 1, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 256, + num_hidden_layers_to_skip: int = 2, + ) -> tuple[torch.Tensor, torch.Tensor]: + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + prompt = [prompt_template["template"].format(p) for p in prompt] + + crop_start = prompt_template.get("crop_start", None) + if crop_start is None: + prompt_template_input = self.tokenizer( + prompt_template["template"], + padding="max_length", + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=False, + ) + crop_start = prompt_template_input["input_ids"].shape[-1] + # Remove <|eot_id|> token and placeholder {} + crop_start -= 2 + + max_sequence_length += crop_start + text_inputs = self.tokenizer( + prompt, + max_length=max_sequence_length, + padding="max_length", + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + prompt_embeds = prompt_embeds.to(dtype=dtype) + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len) + + return prompt_embeds, prompt_attention_mask + + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_videos_per_prompt: int = 1, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 77, + ) -> torch.Tensor: + device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str] = None, + prompt_template: dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 256, + ): + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( + prompt, + prompt_template, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=max_sequence_length, + ) + + if pooled_prompt_embeds is None: + if prompt_2 is None: + prompt_2 = prompt + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=77, + ) + + return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + prompt_template=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_template is not None: + if not isinstance(prompt_template, dict): + raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}") + if "template" not in prompt_template: + raise ValueError( + f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" + ) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 32, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] = None, + negative_prompt: str | list[str] = None, + negative_prompt_2: str | list[str] = None, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + num_inference_steps: int = 50, + sigmas: list[float] = None, + true_cfg_scale: float = 1.0, + guidance_scale: float = 6.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + prompt_template: dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + max_sequence_length: int = 256, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + height (`int`, defaults to `720`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `129`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and + `negative_prompt` is provided. + guidance_scale (`float`, defaults to `6.0`): + Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~HunyuanVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds, + callback_on_step_end_tensor_inputs, + prompt_template, + ) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + transformer_dtype = self.transformer.dtype + prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + prompt_embeds = prompt_embeds.to(transformer_dtype) + prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) + pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) + + if do_true_cfg: + negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + prompt_attention_mask=negative_prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + # 6. Prepare guidance condition + guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents.to(transformer_dtype) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + pooled_projections=pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_attention_mask=negative_prompt_attention_mask, + pooled_projections=negative_pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HunyuanVideoPipelineOutput(frames=video) diff --git a/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py b/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py new file mode 100644 index 0000000000000000000000000000000000000000..f82f26eea5b9a723b79039376480d49e2517ffbb --- /dev/null +++ b/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py @@ -0,0 +1,1134 @@ +# Copyright 2025 The Framepack Team, The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from enum import Enum +from typing import Any, Callable + +import numpy as np +import torch +from transformers import ( + CLIPTextModel, + CLIPTokenizer, + LlamaModel, + LlamaTokenizerFast, + SiglipImageProcessor, + SiglipVisionModel, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import HunyuanVideoLoraLoaderMixin +from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoFramepackTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HunyuanVideoFramepackPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# TODO(yiyi): We can pack the checkpoints nicely with modular loader +EXAMPLE_DOC_STRING = """ + Examples: + ##### Image-to-Video + + ```python + >>> import torch + >>> from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel + >>> from diffusers.utils import export_to_video, load_image + >>> from transformers import SiglipImageProcessor, SiglipVisionModel + + >>> transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained( + ... "lllyasviel/FramePackI2V_HY", torch_dtype=torch.bfloat16 + ... ) + >>> feature_extractor = SiglipImageProcessor.from_pretrained( + ... "lllyasviel/flux_redux_bfl", subfolder="feature_extractor" + ... ) + >>> image_encoder = SiglipVisionModel.from_pretrained( + ... "lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16 + ... ) + >>> pipe = HunyuanVideoFramepackPipeline.from_pretrained( + ... "hunyuanvideo-community/HunyuanVideo", + ... transformer=transformer, + ... feature_extractor=feature_extractor, + ... image_encoder=image_encoder, + ... torch_dtype=torch.float16, + ... ) + >>> pipe.vae.enable_tiling() + >>> pipe.to("cuda") + + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png" + ... ) + >>> output = pipe( + ... image=image, + ... prompt="A penguin dancing in the snow", + ... height=832, + ... width=480, + ... num_frames=91, + ... num_inference_steps=30, + ... guidance_scale=9.0, + ... generator=torch.Generator().manual_seed(0), + ... sampling_type="inverted_anti_drifting", + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=30) + ``` + + ##### First and Last Image-to-Video + + ```python + >>> import torch + >>> from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel + >>> from diffusers.utils import export_to_video, load_image + >>> from transformers import SiglipImageProcessor, SiglipVisionModel + + >>> transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained( + ... "lllyasviel/FramePackI2V_HY", torch_dtype=torch.bfloat16 + ... ) + >>> feature_extractor = SiglipImageProcessor.from_pretrained( + ... "lllyasviel/flux_redux_bfl", subfolder="feature_extractor" + ... ) + >>> image_encoder = SiglipVisionModel.from_pretrained( + ... "lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16 + ... ) + >>> pipe = HunyuanVideoFramepackPipeline.from_pretrained( + ... "hunyuanvideo-community/HunyuanVideo", + ... transformer=transformer, + ... feature_extractor=feature_extractor, + ... image_encoder=image_encoder, + ... torch_dtype=torch.float16, + ... ) + >>> pipe.to("cuda") + + >>> prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective." + >>> first_image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png" + ... ) + >>> last_image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png" + ... ) + >>> output = pipe( + ... image=first_image, + ... last_image=last_image, + ... prompt=prompt, + ... height=512, + ... width=512, + ... num_frames=91, + ... num_inference_steps=30, + ... guidance_scale=9.0, + ... generator=torch.Generator().manual_seed(0), + ... sampling_type="inverted_anti_drifting", + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=30) + ``` +""" + + +DEFAULT_PROMPT_TEMPLATE = { + "template": ( + "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + ), + "crop_start": 95, +} + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FramepackSamplingType(str, Enum): + VANILLA = "vanilla" + INVERTED_ANTI_DRIFTING = "inverted_anti_drifting" + + +class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using HunyuanVideo. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`LlamaModel`]): + [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + tokenizer (`LlamaTokenizer`): + Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + transformer ([`HunyuanVideoTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLHunyuanVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder_2 ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer_2 (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + text_encoder: LlamaModel, + tokenizer: LlamaTokenizerFast, + transformer: HunyuanVideoFramepackTransformer3DModel, + vae: AutoencoderKLHunyuanVideo, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + image_encoder: SiglipVisionModel, + feature_extractor: SiglipImageProcessor, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + + self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline._get_llama_prompt_embeds + def _get_llama_prompt_embeds( + self, + prompt: str | list[str], + prompt_template: dict[str, Any], + num_videos_per_prompt: int = 1, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 256, + num_hidden_layers_to_skip: int = 2, + ) -> tuple[torch.Tensor, torch.Tensor]: + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + prompt = [prompt_template["template"].format(p) for p in prompt] + + crop_start = prompt_template.get("crop_start", None) + if crop_start is None: + prompt_template_input = self.tokenizer( + prompt_template["template"], + padding="max_length", + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=False, + ) + crop_start = prompt_template_input["input_ids"].shape[-1] + # Remove <|eot_id|> token and placeholder {} + crop_start -= 2 + + max_sequence_length += crop_start + text_inputs = self.tokenizer( + prompt, + max_length=max_sequence_length, + padding="max_length", + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + prompt_embeds = prompt_embeds.to(dtype=dtype) + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_videos_per_prompt: int = 1, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 77, + ) -> torch.Tensor: + device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str] = None, + prompt_template: dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 256, + ): + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( + prompt, + prompt_template, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=max_sequence_length, + ) + + if pooled_prompt_embeds is None: + if prompt_2 is None: + prompt_2 = prompt + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=77, + ) + + return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask + + def encode_image(self, image: torch.Tensor, device: torch.device | None = None, dtype: torch.dtype | None = None): + device = device or self._execution_device + image = (image + 1) / 2.0 # [-1, 1] -> [0, 1] + image = self.feature_extractor(images=image, return_tensors="pt", do_rescale=False).to( + device=device, dtype=self.image_encoder.dtype + ) + image_embeds = self.image_encoder(**image).last_hidden_state + return image_embeds.to(dtype=dtype) + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + prompt_template=None, + image=None, + image_latents=None, + last_image=None, + last_image_latents=None, + sampling_type=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_template is not None: + if not isinstance(prompt_template, dict): + raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}") + if "template" not in prompt_template: + raise ValueError( + f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" + ) + + sampling_types = [x.value for x in FramepackSamplingType.__members__.values()] + if sampling_type not in sampling_types: + raise ValueError(f"`sampling_type` has to be one of '{sampling_types}' but is '{sampling_type}'") + + if image is not None and image_latents is not None: + raise ValueError("Only one of `image` or `image_latents` can be passed.") + if last_image is not None and last_image_latents is not None: + raise ValueError("Only one of `last_image` or `last_image_latents` can be passed.") + if sampling_type != FramepackSamplingType.INVERTED_ANTI_DRIFTING and ( + last_image is not None or last_image_latents is not None + ): + raise ValueError( + 'Only `"inverted_anti_drifting"` inference type supports `last_image` or `last_image_latents`.' + ) + + def prepare_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 16, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def prepare_image_latents( + self, + image: torch.Tensor, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + device = device or self._execution_device + if latents is None: + image = image.unsqueeze(2).to(device=device, dtype=self.vae.dtype) + latents = self.vae.encode(image).latent_dist.sample(generator=generator) + latents = latents * self.vae.config.scaling_factor + return latents.to(device=device, dtype=dtype) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + last_image: PipelineImageInput | None = None, + prompt: str | list[str] = None, + prompt_2: str | list[str] = None, + negative_prompt: str | list[str] = None, + negative_prompt_2: str | list[str] = None, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + latent_window_size: int = 9, + num_inference_steps: int = 50, + sigmas: list[float] = None, + true_cfg_scale: float = 1.0, + guidance_scale: float = 6.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + image_latents: torch.Tensor | None = None, + last_image_latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + prompt_template: dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + max_sequence_length: int = 256, + sampling_type: FramepackSamplingType = FramepackSamplingType.INVERTED_ANTI_DRIFTING, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`): + The image to be used as the starting point for the video generation. + last_image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`, *optional*): + The optional last image to be used as the ending point for the video generation. This is useful for + generating transitions between two images. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + height (`int`, defaults to `720`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `129`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + guidance_scale (`float`, defaults to `6.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. Note that the only available + HunyuanVideo model is CFG-distilled, which means that traditional guidance between unconditional and + conditional latent is not applied. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + image_latents (`torch.Tensor`, *optional*): + Pre-encoded image latents. If not provided, the image will be encoded using the VAE. + last_image_latents (`torch.Tensor`, *optional*): + Pre-encoded last image latents. If not provided, the last image will be encoded using the VAE. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideoFramepackPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~HunyuanVideoFramepackPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideoFramepackPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images and the second element is a list + of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) + content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds, + callback_on_step_end_tensor_inputs, + prompt_template, + image, + image_latents, + last_image, + last_image_latents, + sampling_type, + ) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + transformer_dtype = self.transformer.dtype + vae_dtype = self.vae.dtype + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + transformer_dtype = self.transformer.dtype + prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + prompt_embeds = prompt_embeds.to(transformer_dtype) + prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) + pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) + + if do_true_cfg: + negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + prompt_attention_mask=negative_prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype) + + # 4. Prepare image + image = self.video_processor.preprocess(image, height, width) + image_embeds = self.encode_image(image, device=device).to(transformer_dtype) + if last_image is not None: + # Credits: https://github.com/lllyasviel/FramePack/pull/167 + # Users can modify the weighting strategy applied here + last_image = self.video_processor.preprocess(last_image, height, width) + last_image_embeds = self.encode_image(last_image, device=device).to(transformer_dtype) + last_image_embeds = (image_embeds + last_image_embeds) / 2 + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + window_num_frames = (latent_window_size - 1) * self.vae_scale_factor_temporal + 1 + num_latent_sections = max(1, (num_frames + window_num_frames - 1) // window_num_frames) + history_video = None + total_generated_latent_frames = 0 + + image_latents = self.prepare_image_latents( + image, dtype=torch.float32, device=device, generator=generator, latents=image_latents + ) + if last_image is not None: + last_image_latents = self.prepare_image_latents( + last_image, dtype=torch.float32, device=device, generator=generator + ) + + # Specific to the released checkpoints: + # - https://huggingface.co/lllyasviel/FramePackI2V_HY + # - https://huggingface.co/lllyasviel/FramePack_F1_I2V_HY_20250503 + # TODO: find a more generic way in future if there are more checkpoints + if sampling_type == FramepackSamplingType.INVERTED_ANTI_DRIFTING: + history_sizes = [1, 2, 16] + history_latents = torch.zeros( + batch_size, + num_channels_latents, + sum(history_sizes), + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + device=device, + dtype=torch.float32, + ) + + elif sampling_type == FramepackSamplingType.VANILLA: + history_sizes = [16, 2, 1] + history_latents = torch.zeros( + batch_size, + num_channels_latents, + sum(history_sizes), + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + device=device, + dtype=torch.float32, + ) + history_latents = torch.cat([history_latents, image_latents], dim=2) + total_generated_latent_frames += 1 + + else: + assert False + + # 6. Prepare guidance condition + guidance = torch.tensor([guidance_scale] * batch_size, dtype=transformer_dtype, device=device) * 1000.0 + + # 7. Denoising loop + for k in range(num_latent_sections): + if sampling_type == FramepackSamplingType.INVERTED_ANTI_DRIFTING: + latent_paddings = list(reversed(range(num_latent_sections))) + if num_latent_sections > 4: + latent_paddings = [3] + [2] * (num_latent_sections - 3) + [1, 0] + + is_first_section = k == 0 + is_last_section = k == num_latent_sections - 1 + latent_padding_size = latent_paddings[k] * latent_window_size + + indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, *history_sizes])) + ( + indices_prefix, + indices_padding, + indices_latents, + indices_latents_history_1x, + indices_latents_history_2x, + indices_latents_history_4x, + ) = indices.split([1, latent_padding_size, latent_window_size, *history_sizes], dim=0) + # Inverted anti-drifting sampling: Figure 2(c) in the paper + indices_clean_latents = torch.cat([indices_prefix, indices_latents_history_1x], dim=0) + + latents_prefix = image_latents + latents_history_1x, latents_history_2x, latents_history_4x = history_latents[ + :, :, : sum(history_sizes) + ].split(history_sizes, dim=2) + if last_image is not None and is_first_section: + latents_history_1x = last_image_latents + latents_clean = torch.cat([latents_prefix, latents_history_1x], dim=2) + + elif sampling_type == FramepackSamplingType.VANILLA: + indices = torch.arange(0, sum([1, *history_sizes, latent_window_size])) + ( + indices_prefix, + indices_latents_history_4x, + indices_latents_history_2x, + indices_latents_history_1x, + indices_latents, + ) = indices.split([1, *history_sizes, latent_window_size], dim=0) + indices_clean_latents = torch.cat([indices_prefix, indices_latents_history_1x], dim=0) + + latents_prefix = image_latents + latents_history_4x, latents_history_2x, latents_history_1x = history_latents[ + :, :, -sum(history_sizes) : + ].split(history_sizes, dim=2) + latents_clean = torch.cat([latents_prefix, latents_history_1x], dim=2) + + else: + assert False + + latents = self.prepare_latents( + batch_size, + num_channels_latents, + height, + width, + window_num_frames, + dtype=torch.float32, + device=device, + generator=generator, + latents=None, + ) + + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + image_seq_len = ( + latents.shape[2] * latents.shape[3] * latents.shape[4] / self.transformer.config.patch_size**2 + ) + exp_max = 7.0 + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + mu = min(mu, math.log(exp_max)) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu + ) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + timestep = t.expand(latents.shape[0]) + + noise_pred = self.transformer( + hidden_states=latents.to(transformer_dtype), + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + pooled_projections=pooled_prompt_embeds, + image_embeds=image_embeds, + indices_latents=indices_latents, + guidance=guidance, + latents_clean=latents_clean.to(transformer_dtype), + indices_latents_clean=indices_clean_latents, + latents_history_2x=latents_history_2x.to(transformer_dtype), + indices_latents_history_2x=indices_latents_history_2x, + latents_history_4x=latents_history_4x.to(transformer_dtype), + indices_latents_history_4x=indices_latents_history_4x, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + neg_noise_pred = self.transformer( + hidden_states=latents.to(transformer_dtype), + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_attention_mask=negative_prompt_attention_mask, + pooled_projections=negative_pooled_prompt_embeds, + image_embeds=image_embeds, + indices_latents=indices_latents, + guidance=guidance, + latents_clean=latents_clean.to(transformer_dtype), + indices_latents_clean=indices_clean_latents, + latents_history_2x=latents_history_2x.to(transformer_dtype), + indices_latents_history_2x=indices_latents_history_2x, + latents_history_4x=latents_history_4x.to(transformer_dtype), + indices_latents_history_4x=indices_latents_history_4x, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.float(), t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if sampling_type == FramepackSamplingType.INVERTED_ANTI_DRIFTING: + if is_last_section: + latents = torch.cat([image_latents, latents], dim=2) + total_generated_latent_frames += latents.shape[2] + history_latents = torch.cat([latents, history_latents], dim=2) + real_history_latents = history_latents[:, :, :total_generated_latent_frames] + section_latent_frames = ( + (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2) + ) + index_slice = (slice(None), slice(None), slice(0, section_latent_frames)) + + elif sampling_type == FramepackSamplingType.VANILLA: + total_generated_latent_frames += latents.shape[2] + history_latents = torch.cat([history_latents, latents], dim=2) + real_history_latents = history_latents[:, :, -total_generated_latent_frames:] + section_latent_frames = latent_window_size * 2 + index_slice = (slice(None), slice(None), slice(-section_latent_frames, None)) + + else: + assert False + + if history_video is None: + if not output_type == "latent": + current_latents = real_history_latents.to(vae_dtype) / self.vae.config.scaling_factor + history_video = self.vae.decode(current_latents, return_dict=False)[0] + else: + history_video = [real_history_latents] + else: + if not output_type == "latent": + overlapped_frames = (latent_window_size - 1) * self.vae_scale_factor_temporal + 1 + current_latents = ( + real_history_latents[index_slice].to(vae_dtype) / self.vae.config.scaling_factor + ) + current_video = self.vae.decode(current_latents, return_dict=False)[0] + + if sampling_type == FramepackSamplingType.INVERTED_ANTI_DRIFTING: + history_video = self._soft_append(current_video, history_video, overlapped_frames) + elif sampling_type == FramepackSamplingType.VANILLA: + history_video = self._soft_append(history_video, current_video, overlapped_frames) + else: + assert False + else: + history_video.append(real_history_latents) + + self._current_timestep = None + + if not output_type == "latent": + generated_frames = history_video.size(2) + generated_frames = ( + generated_frames - 1 + ) // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + history_video = history_video[:, :, :generated_frames] + video = self.video_processor.postprocess_video(history_video, output_type=output_type) + else: + video = history_video + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HunyuanVideoFramepackPipelineOutput(frames=video) + + def _soft_append(self, history: torch.Tensor, current: torch.Tensor, overlap: int = 0): + if overlap <= 0: + return torch.cat([history, current], dim=2) + + assert history.shape[2] >= overlap, f"Current length ({history.shape[2]}) must be >= overlap ({overlap})" + assert current.shape[2] >= overlap, f"History length ({current.shape[2]}) must be >= overlap ({overlap})" + + weights = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1) + blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap] + output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2) + + return output.to(history) diff --git a/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py b/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py new file mode 100644 index 0000000000000000000000000000000000000000..c599488c2379ee2ef436f74fae72efe52680f869 --- /dev/null +++ b/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py @@ -0,0 +1,1002 @@ +# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + LlamaTokenizerFast, + LlavaForConditionalGeneration, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import HunyuanVideoLoraLoaderMixin +from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HunyuanVideoPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import HunyuanVideoImageToVideoPipeline, HunyuanVideoTransformer3DModel + >>> from diffusers.utils import load_image, export_to_video + + >>> # Available checkpoints: hunyuanvideo-community/HunyuanVideo-I2V, hunyuanvideo-community/HunyuanVideo-I2V-33ch + >>> model_id = "hunyuanvideo-community/HunyuanVideo-I2V" + >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained( + ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = HunyuanVideoImageToVideoPipeline.from_pretrained( + ... model_id, transformer=transformer, torch_dtype=torch.float16 + ... ) + >>> pipe.vae.enable_tiling() + >>> pipe.to("cuda") + + >>> prompt = "A man with short gray hair plays a red electric guitar." + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/guitar-man.png" + ... ) + + >>> # If using hunyuanvideo-community/HunyuanVideo-I2V + >>> output = pipe(image=image, prompt=prompt, guidance_scale=6.0).frames[0] + + >>> # If using hunyuanvideo-community/HunyuanVideo-I2V-33ch + >>> output = pipe(image=image, prompt=prompt, guidance_scale=1.0, true_cfg_scale=1.0).frames[0] + + >>> export_to_video(output, "output.mp4", fps=15) + ``` +""" + + +DEFAULT_PROMPT_TEMPLATE = { + "template": ( + "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the video by detailing the following aspects according to the reference image: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ), + "crop_start": 103, + "image_emb_start": 5, + "image_emb_end": 581, + "image_emb_len": 576, + "double_return_token_id": 271, +} + + +def _expand_input_ids_with_image_tokens( + text_input_ids, + prompt_attention_mask, + max_sequence_length, + image_token_index, + image_emb_len, + image_emb_start, + image_emb_end, + pad_token_id, +): + special_image_token_mask = text_input_ids == image_token_index + num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) + batch_indices, non_image_indices = torch.where(text_input_ids != image_token_index) + + max_expanded_length = max_sequence_length + (num_special_image_tokens.max() * (image_emb_len - 1)) + new_token_positions = torch.cumsum((special_image_token_mask * (image_emb_len - 1) + 1), -1) - 1 + text_to_overwrite = new_token_positions[batch_indices, non_image_indices] + + expanded_input_ids = torch.full( + (text_input_ids.shape[0], max_expanded_length), + pad_token_id, + dtype=text_input_ids.dtype, + device=text_input_ids.device, + ) + expanded_input_ids[batch_indices, text_to_overwrite] = text_input_ids[batch_indices, non_image_indices] + expanded_input_ids[batch_indices, image_emb_start:image_emb_end] = image_token_index + + expanded_attention_mask = torch.zeros( + (text_input_ids.shape[0], max_expanded_length), + dtype=prompt_attention_mask.dtype, + device=prompt_attention_mask.device, + ) + attn_batch_indices, attention_indices = torch.where(expanded_input_ids != pad_token_id) + expanded_attention_mask[attn_batch_indices, attention_indices] = 1.0 + expanded_attention_mask = expanded_attention_mask.to(prompt_attention_mask.dtype) + position_ids = (expanded_attention_mask.cumsum(-1) - 1).masked_fill_((expanded_attention_mask == 0), 1) + + return { + "input_ids": expanded_input_ids, + "attention_mask": expanded_attention_mask, + "position_ids": position_ids, + } + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin): + r""" + Pipeline for image-to-video generation using HunyuanVideo. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`LlavaForConditionalGeneration`]): + [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + tokenizer (`LlamaTokenizer`): + Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + transformer ([`HunyuanVideoTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLHunyuanVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder_2 ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer_2 (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + text_encoder: LlavaForConditionalGeneration, + tokenizer: LlamaTokenizerFast, + transformer: HunyuanVideoTransformer3DModel, + vae: AutoencoderKLHunyuanVideo, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + image_processor=image_processor, + ) + + self.vae_scaling_factor = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.476986 + self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_llama_prompt_embeds( + self, + image: torch.Tensor, + prompt: str | list[str], + prompt_template: dict[str, Any], + num_videos_per_prompt: int = 1, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 256, + num_hidden_layers_to_skip: int = 2, + image_embed_interleave: int = 2, + ) -> tuple[torch.Tensor, torch.Tensor]: + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_template["template"].format(p) for p in prompt] + + crop_start = prompt_template.get("crop_start", None) + + image_emb_len = prompt_template.get("image_emb_len", 576) + image_emb_start = prompt_template.get("image_emb_start", 5) + image_emb_end = prompt_template.get("image_emb_end", 581) + double_return_token_id = prompt_template.get("double_return_token_id", 271) + + if crop_start is None: + prompt_template_input = self.tokenizer( + prompt_template["template"], + padding="max_length", + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=False, + ) + crop_start = prompt_template_input["input_ids"].shape[-1] + # Remove <|start_header_id|>, <|end_header_id|>, assistant, <|eot_id|>, and placeholder {} + crop_start -= 5 + + max_sequence_length += crop_start + text_inputs = self.tokenizer( + prompt, + max_length=max_sequence_length, + padding="max_length", + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device) + + image_token_index = self.text_encoder.config.image_token_index + pad_token_id = self.text_encoder.config.pad_token_id + expanded_inputs = _expand_input_ids_with_image_tokens( + text_input_ids, + prompt_attention_mask, + max_sequence_length, + image_token_index, + image_emb_len, + image_emb_start, + image_emb_end, + pad_token_id, + ) + prompt_embeds = self.text_encoder( + **expanded_inputs, + pixel_values=image_embeds, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + prompt_embeds = prompt_embeds.to(dtype=dtype) + + if crop_start is not None and crop_start > 0: + text_crop_start = crop_start - 1 + image_emb_len + batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id) + + if last_double_return_token_indices.shape[0] == 3: + # in case the prompt is too long + last_double_return_token_indices = torch.cat( + (last_double_return_token_indices, torch.tensor([text_input_ids.shape[-1]])) + ) + batch_indices = torch.cat((batch_indices, torch.tensor([0]))) + + last_double_return_token_indices = last_double_return_token_indices.reshape(text_input_ids.shape[0], -1)[ + :, -1 + ] + batch_indices = batch_indices.reshape(text_input_ids.shape[0], -1)[:, -1] + assistant_crop_start = last_double_return_token_indices - 1 + image_emb_len - 4 + assistant_crop_end = last_double_return_token_indices - 1 + image_emb_len + attention_mask_assistant_crop_start = last_double_return_token_indices - 4 + attention_mask_assistant_crop_end = last_double_return_token_indices + + prompt_embed_list = [] + prompt_attention_mask_list = [] + image_embed_list = [] + image_attention_mask_list = [] + + for i in range(text_input_ids.shape[0]): + prompt_embed_list.append( + torch.cat( + [ + prompt_embeds[i, text_crop_start : assistant_crop_start[i].item()], + prompt_embeds[i, assistant_crop_end[i].item() :], + ] + ) + ) + prompt_attention_mask_list.append( + torch.cat( + [ + prompt_attention_mask[i, crop_start : attention_mask_assistant_crop_start[i].item()], + prompt_attention_mask[i, attention_mask_assistant_crop_end[i].item() :], + ] + ) + ) + image_embed_list.append(prompt_embeds[i, image_emb_start:image_emb_end]) + image_attention_mask_list.append( + torch.ones(image_embed_list[-1].shape[0]).to(prompt_embeds.device).to(prompt_attention_mask.dtype) + ) + + prompt_embed_list = torch.stack(prompt_embed_list) + prompt_attention_mask_list = torch.stack(prompt_attention_mask_list) + image_embed_list = torch.stack(image_embed_list) + image_attention_mask_list = torch.stack(image_attention_mask_list) + + if 0 < image_embed_interleave < 6: + image_embed_list = image_embed_list[:, ::image_embed_interleave, :] + image_attention_mask_list = image_attention_mask_list[:, ::image_embed_interleave] + + assert ( + prompt_embed_list.shape[0] == prompt_attention_mask_list.shape[0] + and image_embed_list.shape[0] == image_attention_mask_list.shape[0] + ) + + prompt_embeds = torch.cat([image_embed_list, prompt_embed_list], dim=1) + prompt_attention_mask = torch.cat([image_attention_mask_list, prompt_attention_mask_list], dim=1) + + return prompt_embeds, prompt_attention_mask + + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_videos_per_prompt: int = 1, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 77, + ) -> torch.Tensor: + device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output + return prompt_embeds + + def encode_prompt( + self, + image: torch.Tensor, + prompt: str | list[str], + prompt_2: str | list[str] = None, + prompt_template: dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 256, + image_embed_interleave: int = 2, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( + image, + prompt, + prompt_template, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=max_sequence_length, + image_embed_interleave=image_embed_interleave, + ) + + if pooled_prompt_embeds is None: + if prompt_2 is None: + prompt_2 = prompt + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=77, + ) + + return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + prompt_template=None, + true_cfg_scale=1.0, + guidance_scale=1.0, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_template is not None: + if not isinstance(prompt_template, dict): + raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}") + if "template" not in prompt_template: + raise ValueError( + f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" + ) + + if true_cfg_scale > 1.0 and guidance_scale > 1.0: + logger.warning( + "Both `true_cfg_scale` and `guidance_scale` are greater than 1.0. This will result in both " + "classifier-free guidance and embedded-guidance to be applied. This is not recommended " + "as it may lead to higher memory usage, slower inference and potentially worse results." + ) + + def prepare_latents( + self, + image: torch.Tensor, + batch_size: int, + num_channels_latents: int = 32, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + image_condition_type: str = "latent_concat", + ) -> torch.Tensor: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height, latent_width = height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + + image = image.unsqueeze(2) # [B, C, 1, H, W] + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i], "argmax") + for i in range(batch_size) + ] + else: + image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator, "argmax") for img in image] + + image_latents = torch.cat(image_latents, dim=0).to(dtype) * self.vae_scaling_factor + image_latents = image_latents.repeat(1, 1, num_latent_frames, 1, 1) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + t = torch.tensor([0.999]).to(device=device) + latents = latents * t + image_latents * (1 - t) + + if image_condition_type == "token_replace": + image_latents = image_latents[:, :, :1] + + return latents, image_latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PIL.Image.Image, + prompt: str | list[str] = None, + prompt_2: str | list[str] = None, + negative_prompt: str | list[str] = None, + negative_prompt_2: str | list[str] = None, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + num_inference_steps: int = 50, + sigmas: list[float] = None, + true_cfg_scale: float = 1.0, + guidance_scale: float = 1.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + prompt_template: dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + max_sequence_length: int = 256, + image_embed_interleave: int | None = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + height (`int`, defaults to `720`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `129`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + guidance_scale (`float`, defaults to `1.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. Note that the only available + HunyuanVideo model is CFG-distilled, which means that traditional guidance between unconditional and + conditional latent is not applied. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~HunyuanVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds, + callback_on_step_end_tensor_inputs, + prompt_template, + true_cfg_scale, + guidance_scale, + ) + + image_condition_type = self.transformer.config.image_condition_type + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + image_embed_interleave = ( + image_embed_interleave + if image_embed_interleave is not None + else ( + 2 if image_condition_type == "latent_concat" else 4 if image_condition_type == "token_replace" else 1 + ) + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Prepare latent variables + vae_dtype = self.vae.dtype + image_tensor = self.video_processor.preprocess(image, height, width).to(device, vae_dtype) + + if image_condition_type == "latent_concat": + num_channels_latents = (self.transformer.config.in_channels - 1) // 2 + elif image_condition_type == "token_replace": + num_channels_latents = self.transformer.config.in_channels + + latents, image_latents = self.prepare_latents( + image_tensor, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + image_condition_type, + ) + if image_condition_type == "latent_concat": + image_latents[:, :, 1:] = 0 + mask = image_latents.new_ones(image_latents.shape[0], 1, *image_latents.shape[2:]) + mask[:, :, 1:] = 0 + + # 4. Encode input prompt + transformer_dtype = self.transformer.dtype + prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( + image=image, + prompt=prompt, + prompt_2=prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + image_embed_interleave=image_embed_interleave, + ) + prompt_embeds = prompt_embeds.to(transformer_dtype) + prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) + pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) + + if do_true_cfg: + black_image = PIL.Image.new("RGB", (width, height), 0) + negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt( + image=black_image, + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + prompt_attention_mask=negative_prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + + # 6. Prepare guidance condition + guidance = None + if self.transformer.config.guidance_embeds: + guidance = ( + torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 + ) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + if image_condition_type == "latent_concat": + latent_model_input = torch.cat([latents, image_latents, mask], dim=1).to(transformer_dtype) + elif image_condition_type == "token_replace": + latent_model_input = torch.cat([image_latents, latents[:, :, 1:]], dim=2).to(transformer_dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + pooled_projections=pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_attention_mask=negative_prompt_attention_mask, + pooled_projections=negative_pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + if image_condition_type == "latent_concat": + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + elif image_condition_type == "token_replace": + latents = latents = self.scheduler.step( + noise_pred[:, :, 1:], t, latents[:, :, 1:], return_dict=False + )[0] + latents = torch.cat([image_latents, latents], dim=2) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae_scaling_factor + video = self.vae.decode(latents, return_dict=False)[0] + if image_condition_type == "latent_concat": + video = video[:, :, 4:, :, :] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + if image_condition_type == "latent_concat": + video = latents[:, :, 1:, :, :] + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HunyuanVideoPipelineOutput(frames=video) diff --git a/diffusers/pipelines/hunyuan_video/pipeline_output.py b/diffusers/pipelines/hunyuan_video/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..4cf96626277047d35d7d06e1481af6ee2fb56447 --- /dev/null +++ b/diffusers/pipelines/hunyuan_video/pipeline_output.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class HunyuanVideoPipelineOutput(BaseOutput): + r""" + Output class for HunyuanVideo pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor + + +@dataclass +class HunyuanVideoFramepackPipelineOutput(BaseOutput): + r""" + Output class for HunyuanVideo pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. Or, a list of torch tensors where each tensor + corresponds to a latent that decodes to multiple frames. + """ + + frames: torch.Tensor | np.ndarray | list[list[PIL.Image.Image]] | list[torch.Tensor] diff --git a/diffusers/pipelines/hunyuan_video1_5/__init__.py b/diffusers/pipelines/hunyuan_video1_5/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..846320f4ace07d502c0c8d067d63527a80664882 --- /dev/null +++ b/diffusers/pipelines/hunyuan_video1_5/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_hunyuan_video1_5"] = ["HunyuanVideo15Pipeline"] + _import_structure["pipeline_hunyuan_video1_5_image2video"] = ["HunyuanVideo15ImageToVideoPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_hunyuan_video1_5 import HunyuanVideo15Pipeline + from .pipeline_hunyuan_video1_5_image2video import HunyuanVideo15ImageToVideoPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/hunyuan_video1_5/image_processor.py b/diffusers/pipelines/hunyuan_video1_5/image_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..82817365b6a529c2307540af8e3fbb268ad0b05a --- /dev/null +++ b/diffusers/pipelines/hunyuan_video1_5/image_processor.py @@ -0,0 +1,103 @@ +# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +from ...configuration_utils import register_to_config +from ...video_processor import VideoProcessor + + +# copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/hyvideo/utils/data_utils.py#L20 +def generate_crop_size_list(base_size=256, patch_size=16, max_ratio=4.0): + num_patches = round((base_size / patch_size) ** 2) + assert max_ratio >= 1.0 + crop_size_list = [] + wp, hp = num_patches, 1 + while wp > 0: + if max(wp, hp) / min(wp, hp) <= max_ratio: + crop_size_list.append((wp * patch_size, hp * patch_size)) + if (hp + 1) * wp <= num_patches: + hp += 1 + else: + wp -= 1 + return crop_size_list + + +# copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/hyvideo/utils/data_utils.py#L38 +def get_closest_ratio(height: float, width: float, ratios: list, buckets: list): + """ + Get the closest ratio in the buckets. + + Args: + height (float): video height + width (float): video width + ratios (list): video aspect ratio + buckets (list): buckets generated by `generate_crop_size_list` + + Returns: + the closest size in the buckets and the corresponding ratio + """ + aspect_ratio = float(height) / float(width) + diff_ratios = ratios - aspect_ratio + + if aspect_ratio >= 1: + indices = [(index, x) for index, x in enumerate(diff_ratios) if x <= 0] + else: + indices = [(index, x) for index, x in enumerate(diff_ratios) if x >= 0] + + closest_ratio_id = min(indices, key=lambda pair: abs(pair[1]))[0] + closest_size = buckets[closest_ratio_id] + closest_ratio = ratios[closest_ratio_id] + + return closest_size, closest_ratio + + +class HunyuanVideo15ImageProcessor(VideoProcessor): + r""" + Image/video processor to preproces/postprocess the reference image/generatedvideo for the HunyuanVideo1.5 model. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept + `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method. + vae_scale_factor (`int`, *optional*, defaults to `16`): + VAE (spatial) scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of + this factor. + vae_latent_channels (`int`, *optional*, defaults to `32`): + VAE latent channels. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + """ + + @register_to_config + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 16, + vae_latent_channels: int = 32, + do_convert_rgb: bool = True, + ): + super().__init__( + do_resize=do_resize, + vae_scale_factor=vae_scale_factor, + vae_latent_channels=vae_latent_channels, + do_convert_rgb=do_convert_rgb, + ) + + def calculate_default_height_width(self, height: int, width: int, target_size: int): + crop_size_list = generate_crop_size_list(base_size=target_size, patch_size=self.config.vae_scale_factor) + aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list]) + height, width = get_closest_ratio(height, width, aspect_ratios, crop_size_list)[0] + + return height, width diff --git a/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py b/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py new file mode 100644 index 0000000000000000000000000000000000000000..a0adff493ac0e7243f64a8763be74f4a5daeb9a8 --- /dev/null +++ b/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py @@ -0,0 +1,837 @@ +# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import re +from typing import Any + +import numpy as np +import torch +from transformers import ByT5Tokenizer, Qwen2_5_VLTextModel, Qwen2Tokenizer, T5EncoderModel + +from ...guiders import ClassifierFreeGuidance +from ...models import AutoencoderKLHunyuanVideo15, HunyuanVideo15Transformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .image_processor import HunyuanVideo15ImageProcessor +from .pipeline_output import HunyuanVideo15PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import HunyuanVideo15Pipeline + >>> from diffusers.utils import export_to_video + + >>> model_id = "hunyuanvideo-community/HunyuanVideo-1.5-480p_t2v" + >>> pipe = HunyuanVideo15Pipeline.from_pretrained(model_id, torch_dtype=torch.float16) + >>> pipe.vae.enable_tiling() + >>> pipe.to("cuda") + + >>> output = pipe( + ... prompt="A cat walks on the grass, realistic", + ... num_inference_steps=50, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=15) + ``` +""" + + +def format_text_input(prompt: list[str], system_message: str) -> list[dict[str, Any]]: + """ + Apply text to template. + + Args: + prompt (list[str]): Input text. + system_message (str): System message. + + Returns: + list[dict[str, Any]]: List of chat conversation. + """ + + template = [ + [{"role": "system", "content": system_message}, {"role": "user", "content": p if p else " "}] for p in prompt + ] + + return template + + +def extract_glyph_texts(prompt: str) -> list[str]: + """ + Extract glyph texts from prompt using regex pattern. + + Args: + prompt: Input prompt string + + Returns: + List of extracted glyph texts + """ + pattern = r"\"(.*?)\"|“(.*?)”" + matches = re.findall(pattern, prompt) + result = [match[0] or match[1] for match in matches] + result = list(dict.fromkeys(result)) if len(result) > 1 else result + + if result: + formatted_result = ". ".join([f'Text "{text}"' for text in result]) + ". " + else: + formatted_result = None + + return formatted_result + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class HunyuanVideo15Pipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using HunyuanVideo1.5. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + transformer ([`HunyuanVideo15Transformer3DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + vae ([`AutoencoderKLHunyuanVideo15`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer]. + text_encoder_2 ([`T5EncoderModel`]): + [T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel) + variant. + tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer] + guider ([`ClassifierFreeGuidance`]): + [ClassifierFreeGuidance]for classifier free guidance. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + text_encoder: Qwen2_5_VLTextModel, + tokenizer: Qwen2Tokenizer, + transformer: HunyuanVideo15Transformer3DModel, + vae: AutoencoderKLHunyuanVideo15, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: T5EncoderModel, + tokenizer_2: ByT5Tokenizer, + guider: ClassifierFreeGuidance, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + guider=guider, + ) + + self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 16 + self.video_processor = HunyuanVideo15ImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.target_size = self.transformer.config.target_size if getattr(self, "transformer", None) else 640 + self.vision_states_dim = ( + self.transformer.config.image_embed_dim if getattr(self, "transformer", None) else 1152 + ) + self.num_channels_latents = self.vae.config.latent_channels if hasattr(self, "vae") else 32 + # fmt: off + self.system_message = "You are a helpful assistant. Describe the video by detailing the following aspects: \ + 1. The main content and theme of the video. \ + 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ + 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \ + 4. background environment, light, style and atmosphere. \ + 5. camera angles, movements, and transitions used in the video." + # fmt: on + self.prompt_template_encode_start_idx = 108 + self.tokenizer_max_length = 1000 + self.tokenizer_2_max_length = 256 + self.vision_num_semantic_tokens = 729 + self.default_aspect_ratio = (16, 9) # (width: height) + + @staticmethod + def _get_mllm_prompt_embeds( + text_encoder: Qwen2_5_VLTextModel, + tokenizer: Qwen2Tokenizer, + prompt: str | list[str], + device: torch.device, + tokenizer_max_length: int = 1000, + num_hidden_layers_to_skip: int = 2, + # fmt: off + system_message: str = "You are a helpful assistant. Describe the video by detailing the following aspects: \ + 1. The main content and theme of the video. \ + 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ + 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \ + 4. background environment, light, style and atmosphere. \ + 5. camera angles, movements, and transitions used in the video.", + # fmt: on + crop_start: int = 108, + ) -> tuple[torch.Tensor, torch.Tensor]: + prompt = [prompt] if isinstance(prompt, str) else prompt + + prompt = format_text_input(prompt, system_message) + + text_inputs = tokenizer.apply_chat_template( + prompt, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + padding="max_length", + max_length=tokenizer_max_length + crop_start, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + prompt_embeds = text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + return prompt_embeds, prompt_attention_mask + + @staticmethod + def _get_byt5_prompt_embeds( + tokenizer: ByT5Tokenizer, + text_encoder: T5EncoderModel, + prompt: str | list[str], + device: torch.device, + tokenizer_max_length: int = 256, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + + glyph_texts = [extract_glyph_texts(p) for p in prompt] + + prompt_embeds_list = [] + prompt_embeds_mask_list = [] + + for glyph_text in glyph_texts: + if glyph_text is None: + glyph_text_embeds = torch.zeros( + (1, tokenizer_max_length, text_encoder.config.d_model), device=device, dtype=text_encoder.dtype + ) + glyph_text_embeds_mask = torch.zeros((1, tokenizer_max_length), device=device, dtype=torch.int64) + else: + txt_tokens = tokenizer( + glyph_text, + padding="max_length", + max_length=tokenizer_max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ).to(device) + + glyph_text_embeds = text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask.float(), + )[0] + glyph_text_embeds = glyph_text_embeds.to(device=device) + glyph_text_embeds_mask = txt_tokens.attention_mask.to(device=device) + + prompt_embeds_list.append(glyph_text_embeds) + prompt_embeds_mask_list.append(glyph_text_embeds_mask) + + prompt_embeds = torch.cat(prompt_embeds_list, dim=0) + prompt_embeds_mask = torch.cat(prompt_embeds_mask_list, dim=0) + + return prompt_embeds, prompt_embeds_mask + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + dtype: torch.dtype | None = None, + batch_size: int = 1, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + prompt_embeds_2: torch.Tensor | None = None, + prompt_embeds_mask_2: torch.Tensor | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + batch_size (`int`): + batch size of prompts, defaults to 1 + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input + argument. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input + argument using self.tokenizer_2 and self.text_encoder_2. + prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input + argument using self.tokenizer_2 and self.text_encoder_2. + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + if prompt is None: + prompt = [""] * batch_size + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_mllm_prompt_embeds( + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + prompt=prompt, + device=device, + tokenizer_max_length=self.tokenizer_max_length, + system_message=self.system_message, + crop_start=self.prompt_template_encode_start_idx, + ) + + if prompt_embeds_2 is None: + prompt_embeds_2, prompt_embeds_mask_2 = self._get_byt5_prompt_embeds( + tokenizer=self.tokenizer_2, + text_encoder=self.text_encoder_2, + prompt=prompt, + device=device, + tokenizer_max_length=self.tokenizer_2_max_length, + ) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_videos_per_prompt, seq_len) + + _, seq_len_2, _ = prompt_embeds_2.shape + prompt_embeds_2 = prompt_embeds_2.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_2 = prompt_embeds_2.view(batch_size * num_videos_per_prompt, seq_len_2, -1) + prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_mask_2 = prompt_embeds_mask_2.view(batch_size * num_videos_per_prompt, seq_len_2) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds_mask = prompt_embeds_mask.to(dtype=dtype, device=device) + prompt_embeds_2 = prompt_embeds_2.to(dtype=dtype, device=device) + prompt_embeds_mask_2 = prompt_embeds_mask_2.to(dtype=dtype, device=device) + + return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + prompt_embeds_2=None, + prompt_embeds_mask_2=None, + negative_prompt_embeds_2=None, + negative_prompt_embeds_mask_2=None, + ): + if height is None and width is not None: + raise ValueError("If `width` is provided, `height` also have to be provided.") + elif width is None and height is not None: + raise ValueError("If `height` is provided, `width` also have to be provided.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + + if prompt_embeds_2 is not None and prompt_embeds_mask_2 is None: + raise ValueError( + "If `prompt_embeds_2` are provided, `prompt_embeds_mask_2` also have to be passed. Make sure to generate `prompt_embeds_mask_2` from the same text encoder that was used to generate `prompt_embeds_2`." + ) + if negative_prompt_embeds_2 is not None and negative_prompt_embeds_mask_2 is None: + raise ValueError( + "If `negative_prompt_embeds_2` are provided, `negative_prompt_embeds_mask_2` also have to be passed. Make sure to generate `negative_prompt_embeds_mask_2` from the same text encoder that was used to generate `negative_prompt_embeds_2`." + ) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 32, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def prepare_cond_latents_and_mask(self, latents, dtype: torch.dtype | None, device: torch.device | None): + """ + Prepare conditional latents and mask for t2v generation. + + Args: + latents: Main latents tensor (B, C, F, H, W) + + Returns: + tuple: (cond_latents_concat, mask_concat) - both are zero tensors for t2v + """ + batch, channels, frames, height, width = latents.shape + + cond_latents_concat = torch.zeros(batch, channels, frames, height, width, dtype=dtype, device=device) + + mask_concat = torch.zeros(batch, 1, frames, height, width, dtype=dtype, device=device) + + return cond_latents_concat, mask_concat + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_frames: int = 121, + num_inference_steps: int = 50, + sigmas: list[float] = None, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + prompt_embeds_2: torch.Tensor | None = None, + prompt_embeds_mask_2: torch.Tensor | None = None, + negative_prompt_embeds_2: torch.Tensor | None = None, + negative_prompt_embeds_mask_2: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. + height (`int`, *optional*): + The height in pixels of the generated video. + width (`int`, *optional*): + The width in pixels of the generated video. + num_frames (`int`, defaults to `121`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated mask for prompt embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated mask for negative prompt embeddings. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated text embeddings from the second text encoder. Can be used to easily tweak text inputs. + prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated mask for prompt embeddings from the second text encoder. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings from the second text encoder. + negative_prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated mask for negative prompt embeddings from the second text encoder. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. Choose between "np", "pt", or "latent". + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideo15PipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + Returns: + [`~HunyuanVideo15PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideo15PipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated videos. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + prompt_embeds_2=prompt_embeds_2, + prompt_embeds_mask_2=prompt_embeds_mask_2, + negative_prompt_embeds_2=negative_prompt_embeds_2, + negative_prompt_embeds_mask_2=negative_prompt_embeds_mask_2, + ) + + if height is None and width is None: + height, width = self.video_processor.calculate_default_height_width( + self.default_aspect_ratio[1], self.default_aspect_ratio[0], self.target_size + ) + + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + batch_size=batch_size, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + prompt_embeds_2=prompt_embeds_2, + prompt_embeds_mask_2=prompt_embeds_mask_2, + ) + + if self.guider._enabled and self.guider.num_conditions > 1: + ( + negative_prompt_embeds, + negative_prompt_embeds_mask, + negative_prompt_embeds_2, + negative_prompt_embeds_mask_2, + ) = self.encode_prompt( + prompt=negative_prompt, + device=device, + dtype=self.transformer.dtype, + batch_size=batch_size, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + prompt_embeds_2=negative_prompt_embeds_2, + prompt_embeds_mask_2=negative_prompt_embeds_mask_2, + ) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + + # 5. Prepare latent variables + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + self.num_channels_latents, + height, + width, + num_frames, + self.transformer.dtype, + device, + generator, + latents, + ) + cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask(latents, self.transformer.dtype, device) + image_embeds = torch.zeros( + batch_size, + self.vision_num_semantic_tokens, + self.vision_states_dim, + dtype=self.transformer.dtype, + device=device, + ) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents, cond_latents_concat, mask_concat], dim=1) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) + + # Step 1: Collect model inputs needed for the guidance method + # conditional inputs should always be first element in the tuple + guider_inputs = { + "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds), + "encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask), + "encoder_hidden_states_2": (prompt_embeds_2, negative_prompt_embeds_2), + "encoder_attention_mask_2": (prompt_embeds_mask_2, negative_prompt_embeds_mask_2), + } + + # Step 2: Update guider's internal state for this denoising step + self.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t) + + # Step 3: Prepare batched model inputs based on the guidance method + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). + guider_state = self.guider.prepare_inputs(guider_inputs) + # Step 4: Run the denoiser for each batch + # Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.). + # We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred. + for guider_state_batch in guider_state: + self.guider.prepare_models(self.transformer) + + # Extract conditioning kwargs for this batch (e.g., encoder_hidden_states) + cond_kwargs = { + input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys() + } + + # e.g. "pred_cond"/"pred_uncond" + context_name = getattr(guider_state_batch, self.guider._identifier_key) + with self.transformer.cache_context(context_name): + # Run denoiser and store noise prediction in this batch + guider_state_batch.noise_pred = self.transformer( + hidden_states=latent_model_input, + image_embeds=image_embeds, + timestep=timestep, + attention_kwargs=self.attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + + # Cleanup model (e.g., remove hooks) + self.guider.cleanup_models(self.transformer) + + # Step 5: Combine predictions using the guidance method + # The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm. + # Continuing the CFG example, the guider receives: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0 + # {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1 + # ] + # And extracts predictions using the __guidance_identifier__: + # pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond + # pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond + # Then applies CFG formula: + # noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + # Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond) + noise_pred = self.guider(guider_state)[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + # 8. decode the latents to video and postprocess + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HunyuanVideo15PipelineOutput(frames=video) diff --git a/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py b/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py new file mode 100644 index 0000000000000000000000000000000000000000..791dec0735248a12a4c68743084bded7663a3c25 --- /dev/null +++ b/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py @@ -0,0 +1,960 @@ +# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import re +from typing import Any + +import numpy as np +import PIL +import torch +from transformers import ( + ByT5Tokenizer, + Qwen2_5_VLTextModel, + Qwen2Tokenizer, + SiglipImageProcessor, + SiglipVisionModel, + T5EncoderModel, +) + +from ...guiders import ClassifierFreeGuidance +from ...models import AutoencoderKLHunyuanVideo15, HunyuanVideo15Transformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .image_processor import HunyuanVideo15ImageProcessor +from .pipeline_output import HunyuanVideo15PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import HunyuanVideo15ImageToVideoPipeline + >>> from diffusers.utils import export_to_video + + >>> model_id = "hunyuanvideo-community/HunyuanVideo-1.5-480p_i2v" + >>> pipe = HunyuanVideo15ImageToVideoPipeline.from_pretrained(model_id, torch_dtype=torch.float16) + >>> pipe.vae.enable_tiling() + >>> pipe.to("cuda") + + >>> image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/wan_i2v_input.JPG") + + >>> output = pipe( + ... prompt="Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.", + ... image=image, + ... num_inference_steps=50, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=24) + ``` +""" + + +# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.format_text_input +def format_text_input(prompt: list[str], system_message: str) -> list[dict[str, Any]]: + """ + Apply text to template. + + Args: + prompt (list[str]): Input text. + system_message (str): System message. + + Returns: + list[dict[str, Any]]: List of chat conversation. + """ + + template = [ + [{"role": "system", "content": system_message}, {"role": "user", "content": p if p else " "}] for p in prompt + ] + + return template + + +# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.extract_glyph_texts +def extract_glyph_texts(prompt: str) -> list[str]: + """ + Extract glyph texts from prompt using regex pattern. + + Args: + prompt: Input prompt string + + Returns: + List of extracted glyph texts + """ + pattern = r"\"(.*?)\"|“(.*?)”" + matches = re.findall(pattern, prompt) + result = [match[0] or match[1] for match in matches] + result = list(dict.fromkeys(result)) if len(result) > 1 else result + + if result: + formatted_result = ". ".join([f'Text "{text}"' for text in result]) + ". " + else: + formatted_result = None + + return formatted_result + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class HunyuanVideo15ImageToVideoPipeline(DiffusionPipeline): + r""" + Pipeline for image-to-video generation using HunyuanVideo1.5. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + transformer ([`HunyuanVideo15Transformer3DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + vae ([`AutoencoderKLHunyuanVideo15`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer]. + text_encoder_2 ([`T5EncoderModel`]): + [T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel) + variant. + tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer] + guider ([`ClassifierFreeGuidance`]): + [ClassifierFreeGuidance]for classifier free guidance. + image_encoder ([`SiglipVisionModel`]): + [SiglipVisionModel](https://huggingface.co/docs/transformers/en/model_doc/siglip#transformers.SiglipVisionModel) + variant. + feature_extractor ([`SiglipImageProcessor`]): + [SiglipImageProcessor](https://huggingface.co/docs/transformers/en/model_doc/siglip#transformers.SiglipImageProcessor) + variant. + """ + + model_cpu_offload_seq = "image_encoder->text_encoder->transformer->vae" + + def __init__( + self, + text_encoder: Qwen2_5_VLTextModel, + tokenizer: Qwen2Tokenizer, + transformer: HunyuanVideo15Transformer3DModel, + vae: AutoencoderKLHunyuanVideo15, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: T5EncoderModel, + tokenizer_2: ByT5Tokenizer, + guider: ClassifierFreeGuidance, + image_encoder: SiglipVisionModel, + feature_extractor: SiglipImageProcessor, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + guider=guider, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + + self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 16 + self.video_processor = HunyuanVideo15ImageProcessor( + vae_scale_factor=self.vae_scale_factor_spatial, do_resize=False, do_convert_rgb=True + ) + self.target_size = self.transformer.config.target_size if getattr(self, "transformer", None) else 640 + self.vision_states_dim = ( + self.transformer.config.image_embed_dim if getattr(self, "transformer", None) else 1152 + ) + self.num_channels_latents = self.vae.config.latent_channels if hasattr(self, "vae") else 32 + # fmt: off + self.system_message = "You are a helpful assistant. Describe the video by detailing the following aspects: \ + 1. The main content and theme of the video. \ + 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ + 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \ + 4. background environment, light, style and atmosphere. \ + 5. camera angles, movements, and transitions used in the video." + # fmt: on + self.prompt_template_encode_start_idx = 108 + self.tokenizer_max_length = 1000 + self.tokenizer_2_max_length = 256 + self.vision_num_semantic_tokens = 729 + + @staticmethod + # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline._get_mllm_prompt_embeds + def _get_mllm_prompt_embeds( + text_encoder: Qwen2_5_VLTextModel, + tokenizer: Qwen2Tokenizer, + prompt: str | list[str], + device: torch.device, + tokenizer_max_length: int = 1000, + num_hidden_layers_to_skip: int = 2, + # fmt: off + system_message: str = "You are a helpful assistant. Describe the video by detailing the following aspects: \ + 1. The main content and theme of the video. \ + 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ + 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \ + 4. background environment, light, style and atmosphere. \ + 5. camera angles, movements, and transitions used in the video.", + # fmt: on + crop_start: int = 108, + ) -> tuple[torch.Tensor, torch.Tensor]: + prompt = [prompt] if isinstance(prompt, str) else prompt + + prompt = format_text_input(prompt, system_message) + + text_inputs = tokenizer.apply_chat_template( + prompt, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + padding="max_length", + max_length=tokenizer_max_length + crop_start, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + prompt_embeds = text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + return prompt_embeds, prompt_attention_mask + + @staticmethod + # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline._get_byt5_prompt_embeds + def _get_byt5_prompt_embeds( + tokenizer: ByT5Tokenizer, + text_encoder: T5EncoderModel, + prompt: str | list[str], + device: torch.device, + tokenizer_max_length: int = 256, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + + glyph_texts = [extract_glyph_texts(p) for p in prompt] + + prompt_embeds_list = [] + prompt_embeds_mask_list = [] + + for glyph_text in glyph_texts: + if glyph_text is None: + glyph_text_embeds = torch.zeros( + (1, tokenizer_max_length, text_encoder.config.d_model), device=device, dtype=text_encoder.dtype + ) + glyph_text_embeds_mask = torch.zeros((1, tokenizer_max_length), device=device, dtype=torch.int64) + else: + txt_tokens = tokenizer( + glyph_text, + padding="max_length", + max_length=tokenizer_max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ).to(device) + + glyph_text_embeds = text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask.float(), + )[0] + glyph_text_embeds = glyph_text_embeds.to(device=device) + glyph_text_embeds_mask = txt_tokens.attention_mask.to(device=device) + + prompt_embeds_list.append(glyph_text_embeds) + prompt_embeds_mask_list.append(glyph_text_embeds_mask) + + prompt_embeds = torch.cat(prompt_embeds_list, dim=0) + prompt_embeds_mask = torch.cat(prompt_embeds_mask_list, dim=0) + + return prompt_embeds, prompt_embeds_mask + + @staticmethod + def _get_image_latents( + vae: AutoencoderKLHunyuanVideo15, + image_processor: HunyuanVideo15ImageProcessor, + image: PIL.Image.Image, + height: int, + width: int, + device: torch.device, + ) -> torch.Tensor: + vae_dtype = vae.dtype + image_tensor = image_processor.preprocess(image, height=height, width=width).to(device, dtype=vae_dtype) + image_tensor = image_tensor.unsqueeze(2) + image_latents = retrieve_latents(vae.encode(image_tensor), sample_mode="argmax") + image_latents = image_latents * vae.config.scaling_factor + return image_latents + + @staticmethod + def _get_image_embeds( + image_encoder: SiglipVisionModel, + feature_extractor: SiglipImageProcessor, + image: PIL.Image.Image, + device: torch.device, + ) -> torch.Tensor: + image_encoder_dtype = next(image_encoder.parameters()).dtype + image = feature_extractor.preprocess(images=image, do_resize=True, return_tensors="pt", do_convert_rgb=True) + image = image.to(device=device, dtype=image_encoder_dtype) + image_enc_hidden_states = image_encoder(**image).last_hidden_state + + return image_enc_hidden_states + + def encode_image( + self, + image: PIL.Image.Image, + batch_size: int, + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + image_embeds = self._get_image_embeds( + image_encoder=self.image_encoder, + feature_extractor=self.feature_extractor, + image=image, + device=device, + ) + image_embeds = image_embeds.repeat(batch_size, 1, 1) + image_embeds = image_embeds.to(device=device, dtype=dtype) + return image_embeds + + # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + dtype: torch.dtype | None = None, + batch_size: int = 1, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + prompt_embeds_2: torch.Tensor | None = None, + prompt_embeds_mask_2: torch.Tensor | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + batch_size (`int`): + batch size of prompts, defaults to 1 + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input + argument. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input + argument using self.tokenizer_2 and self.text_encoder_2. + prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input + argument using self.tokenizer_2 and self.text_encoder_2. + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + if prompt is None: + prompt = [""] * batch_size + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_mllm_prompt_embeds( + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + prompt=prompt, + device=device, + tokenizer_max_length=self.tokenizer_max_length, + system_message=self.system_message, + crop_start=self.prompt_template_encode_start_idx, + ) + + if prompt_embeds_2 is None: + prompt_embeds_2, prompt_embeds_mask_2 = self._get_byt5_prompt_embeds( + tokenizer=self.tokenizer_2, + text_encoder=self.text_encoder_2, + prompt=prompt, + device=device, + tokenizer_max_length=self.tokenizer_2_max_length, + ) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_videos_per_prompt, seq_len) + + _, seq_len_2, _ = prompt_embeds_2.shape + prompt_embeds_2 = prompt_embeds_2.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_2 = prompt_embeds_2.view(batch_size * num_videos_per_prompt, seq_len_2, -1) + prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_mask_2 = prompt_embeds_mask_2.view(batch_size * num_videos_per_prompt, seq_len_2) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds_mask = prompt_embeds_mask.to(dtype=dtype, device=device) + prompt_embeds_2 = prompt_embeds_2.to(dtype=dtype, device=device) + prompt_embeds_mask_2 = prompt_embeds_mask_2.to(dtype=dtype, device=device) + + return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 + + def check_inputs( + self, + prompt, + image: PIL.Image.Image, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + prompt_embeds_2=None, + prompt_embeds_mask_2=None, + negative_prompt_embeds_2=None, + negative_prompt_embeds_mask_2=None, + ): + if not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `PIL.Image.Image` but is {type(image)}") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + + if prompt_embeds_2 is not None and prompt_embeds_mask_2 is None: + raise ValueError( + "If `prompt_embeds_2` are provided, `prompt_embeds_mask_2` also have to be passed. Make sure to generate `prompt_embeds_mask_2` from the same text encoder that was used to generate `prompt_embeds_2`." + ) + if negative_prompt_embeds_2 is not None and negative_prompt_embeds_mask_2 is None: + raise ValueError( + "If `negative_prompt_embeds_2` are provided, `negative_prompt_embeds_mask_2` also have to be passed. Make sure to generate `negative_prompt_embeds_mask_2` from the same text encoder that was used to generate `negative_prompt_embeds_2`." + ) + + # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline.prepare_latents + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 32, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def prepare_cond_latents_and_mask( + self, + latents: torch.Tensor, + image: PIL.Image.Image, + batch_size: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + ): + """ + Prepare conditional latents and mask for t2v generation. + + Args: + latents: Main latents tensor (B, C, F, H, W) + + Returns: + tuple: (cond_latents_concat, mask_concat) - both are zero tensors for t2v + """ + + batch, channels, frames, height, width = latents.shape + + image_latents = self._get_image_latents( + vae=self.vae, + image_processor=self.video_processor, + image=image, + height=height, + width=width, + device=device, + ) + + latent_condition = image_latents.repeat(batch_size, 1, frames, 1, 1) + latent_condition[:, :, 1:, :, :] = 0 + latent_condition = latent_condition.to(device=device, dtype=dtype) + + latent_mask = torch.zeros(batch, 1, frames, height, width, dtype=dtype, device=device) + latent_mask[:, :, 0, :, :] = 1.0 + + return latent_condition, latent_mask + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PIL.Image.Image, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + num_frames: int = 121, + num_inference_steps: int = 50, + sigmas: list[float] = None, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + prompt_embeds_2: torch.Tensor | None = None, + prompt_embeds_mask_2: torch.Tensor | None = None, + negative_prompt_embeds_2: torch.Tensor | None = None, + negative_prompt_embeds_mask_2: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image`): + The input image to condition video generation on. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds` + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. + num_frames (`int`, defaults to `121`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated mask for prompt embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated mask for negative prompt embeddings. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated text embeddings from the second text encoder. Can be used to easily tweak text inputs. + prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated mask for prompt embeddings from the second text encoder. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings from the second text encoder. + negative_prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated mask for negative prompt embeddings from the second text encoder. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. Choose between "np", "pt", or "latent". + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideo15PipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + Returns: + [`~HunyuanVideo15PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideo15PipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated videos. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + image=image, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + prompt_embeds_2=prompt_embeds_2, + prompt_embeds_mask_2=prompt_embeds_mask_2, + negative_prompt_embeds_2=negative_prompt_embeds_2, + negative_prompt_embeds_mask_2=negative_prompt_embeds_mask_2, + ) + + height, width = self.video_processor.calculate_default_height_width( + height=image.size[1], width=image.size[0], target_size=self.target_size + ) + image = self.video_processor.resize(image, height=height, width=width, resize_mode="crop") + + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode image + image_embeds = self.encode_image( + image=image, + batch_size=batch_size * num_videos_per_prompt, + device=device, + dtype=self.transformer.dtype, + ) + + # 4. Encode input prompt + prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + batch_size=batch_size, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + prompt_embeds_2=prompt_embeds_2, + prompt_embeds_mask_2=prompt_embeds_mask_2, + ) + + if self.guider._enabled and self.guider.num_conditions > 1: + ( + negative_prompt_embeds, + negative_prompt_embeds_mask, + negative_prompt_embeds_2, + negative_prompt_embeds_mask_2, + ) = self.encode_prompt( + prompt=negative_prompt, + device=device, + dtype=self.transformer.dtype, + batch_size=batch_size, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + prompt_embeds_2=negative_prompt_embeds_2, + prompt_embeds_mask_2=negative_prompt_embeds_mask_2, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + + # 6. Prepare latent variables + latents = self.prepare_latents( + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=self.num_channels_latents, + height=height, + width=width, + num_frames=num_frames, + dtype=self.transformer.dtype, + device=device, + generator=generator, + latents=latents, + ) + + cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask( + latents=latents, + image=image, + batch_size=batch_size * num_videos_per_prompt, + height=height, + width=width, + dtype=self.transformer.dtype, + device=device, + ) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents, cond_latents_concat, mask_concat], dim=1) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) + + if self.transformer.config.use_meanflow: + if i == len(timesteps) - 1: + timestep_r = torch.tensor([0.0], device=device) + else: + timestep_r = timesteps[i + 1] + timestep_r = timestep_r.expand(latents.shape[0]).to(latents.dtype) + else: + timestep_r = None + + # Step 1: Collect model inputs needed for the guidance method + # conditional inputs should always be first element in the tuple + guider_inputs = { + "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds), + "encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask), + "encoder_hidden_states_2": (prompt_embeds_2, negative_prompt_embeds_2), + "encoder_attention_mask_2": (prompt_embeds_mask_2, negative_prompt_embeds_mask_2), + } + + # Step 2: Update guider's internal state for this denoising step + self.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t) + + # Step 3: Prepare batched model inputs based on the guidance method + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). + guider_state = self.guider.prepare_inputs(guider_inputs) + # Step 4: Run the denoiser for each batch + # Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.). + # We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred. + for guider_state_batch in guider_state: + self.guider.prepare_models(self.transformer) + + # Extract conditioning kwargs for this batch (e.g., encoder_hidden_states) + cond_kwargs = { + input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys() + } + + # e.g. "pred_cond"/"pred_uncond" + context_name = getattr(guider_state_batch, self.guider._identifier_key) + with self.transformer.cache_context(context_name): + # Run denoiser and store noise prediction in this batch + guider_state_batch.noise_pred = self.transformer( + hidden_states=latent_model_input, + image_embeds=image_embeds, + timestep=timestep, + timestep_r=timestep_r, + attention_kwargs=self.attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + + # Cleanup model (e.g., remove hooks) + self.guider.cleanup_models(self.transformer) + + # Step 5: Combine predictions using the guidance method + # The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm. + # Continuing the CFG example, the guider receives: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0 + # {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1 + # ] + # And extracts predictions using the __guidance_identifier__: + # pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond + # pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond + # Then applies CFG formula: + # noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + # Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond) + noise_pred = self.guider(guider_state)[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HunyuanVideo15PipelineOutput(frames=video) diff --git a/diffusers/pipelines/hunyuan_video1_5/pipeline_output.py b/diffusers/pipelines/hunyuan_video1_5/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..ab8767299e55c961bed4d8394e01154029c00544 --- /dev/null +++ b/diffusers/pipelines/hunyuan_video1_5/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class HunyuanVideo15PipelineOutput(BaseOutput): + r""" + Output class for HunyuanVideo1.5 pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/diffusers/pipelines/hunyuandit/__init__.py b/diffusers/pipelines/hunyuandit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8337399106f0585d15fa0a35f607baa2c04b203b --- /dev/null +++ b/diffusers/pipelines/hunyuandit/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_hunyuandit"] = ["HunyuanDiTPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_hunyuandit import HunyuanDiTPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py b/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py new file mode 100644 index 0000000000000000000000000000000000000000..b908dd5dfe8307fba0972ac80ca19530ddbae7c6 --- /dev/null +++ b/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py @@ -0,0 +1,908 @@ +# Copyright 2025 HunyuanDiT Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import torch +from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, HunyuanDiT2DModel +from ...models.embeddings import get_2d_rotary_pos_embed +from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from ...schedulers import DDPMScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import HunyuanDiTPipeline + + >>> pipe = HunyuanDiTPipeline.from_pretrained( + ... "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> # You may also use English prompt as HunyuanDiT supports both English and Chinese + >>> # prompt = "An astronaut riding a horse" + >>> prompt = "一个宇航员在骑马" + >>> image = pipe(prompt).images[0] + ``` +""" + +STANDARD_RATIO = np.array( + [ + 1.0, # 1:1 + 4.0 / 3.0, # 4:3 + 3.0 / 4.0, # 3:4 + 16.0 / 9.0, # 16:9 + 9.0 / 16.0, # 9:16 + ] +) +STANDARD_SHAPE = [ + [(1024, 1024), (1280, 1280)], # 1:1 + [(1024, 768), (1152, 864), (1280, 960)], # 4:3 + [(768, 1024), (864, 1152), (960, 1280)], # 3:4 + [(1280, 768)], # 16:9 + [(768, 1280)], # 9:16 +] +STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE] +SUPPORTED_SHAPE = [ + (1024, 1024), + (1280, 1280), # 1:1 + (1024, 768), + (1152, 864), + (1280, 960), # 4:3 + (768, 1024), + (864, 1152), + (960, 1280), # 3:4 + (1280, 768), # 16:9 + (768, 1280), # 9:16 +] + + +def map_to_standard_shapes(target_width, target_height): + target_ratio = target_width / target_height + closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio)) + closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height)) + width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx] + return width, height + + +def get_resize_crop_region_for_grid(src, tgt_size): + th = tw = tgt_size + h, w = src + + r = h / w + + # resize + if r > 1: + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class HunyuanDiTPipeline(DiffusionPipeline): + r""" + Pipeline for English/Chinese-to-image generation using HunyuanDiT. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + HunyuanDiT uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by + ourselves) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. We use + `sdxl-vae-fp16-fix`. + text_encoder (`~transformers.BertModel`, `~transformers.CLIPTextModel` | None): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + HunyuanDiT uses a fine-tuned [bilingual CLIP]. + tokenizer (`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer` | None): + A `BertTokenizer` or `CLIPTokenizer` to tokenize text. + transformer ([`HunyuanDiT2DModel`]): + The HunyuanDiT model designed by Tencent Hunyuan. + text_encoder_2 (`T5EncoderModel`): + The mT5 embedder. Specifically, it is 't5-v1_1-xxl'. + tokenizer_2 (`T5Tokenizer`): + The tokenizer for the mT5 embedder. + scheduler ([`DDPMScheduler`]): + A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [ + "safety_checker", + "feature_extractor", + "text_encoder_2", + "tokenizer_2", + "text_encoder", + "tokenizer", + ] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "prompt_embeds_2", + "negative_prompt_embeds_2", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: BertModel, + tokenizer: BertTokenizer, + transformer: HunyuanDiT2DModel, + scheduler: DDPMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + text_encoder_2: T5EncoderModel | None = None, + tokenizer_2: T5Tokenizer | None = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + text_encoder_2=text_encoder_2, + ) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + + def encode_prompt( + self, + prompt: str, + device: torch.device = None, + dtype: torch.dtype = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + max_sequence_length: int | None = None, + text_encoder_index: int = 0, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + dtype (`torch.dtype`): + torch dtype + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. + text_encoder_index (`int`, *optional*): + Index of the text encoder to use. `0` for clip and `1` for T5. + """ + if dtype is None: + if self.text_encoder_2 is not None: + dtype = self.text_encoder_2.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + if device is None: + device = self._execution_device + + tokenizers = [self.tokenizer, self.tokenizer_2] + text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = tokenizers[text_encoder_index] + text_encoder = text_encoders[text_encoder_index] + + if max_sequence_length is None: + if text_encoder_index == 0: + max_length = 77 + if text_encoder_index == 1: + max_length = 256 + else: + max_length = max_sequence_length + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = text_encoder( + text_input_ids.to(device), + attention_mask=prompt_attention_mask, + ) + prompt_embeds = prompt_embeds[0] + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_attention_mask = uncond_input.attention_mask.to(device) + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + attention_mask=negative_prompt_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + prompt_embeds_2=None, + negative_prompt_embeds_2=None, + prompt_attention_mask_2=None, + negative_prompt_attention_mask_2=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if prompt_embeds_2 is not None and prompt_attention_mask_2 is None: + raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None: + raise ValueError( + "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`." + ) + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None: + if prompt_embeds_2.shape != negative_prompt_embeds_2.shape: + raise ValueError( + "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but" + f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`" + f" {negative_prompt_embeds_2.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int | None = 50, + guidance_scale: float | None = 5.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float | None = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_2: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_2: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + prompt_attention_mask_2: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask_2: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + guidance_rescale: float = 0.0, + original_size: tuple[int, int] | None = (1024, 1024), + target_size: tuple[int, int] | None = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + use_resolution_binning: bool = True, + ): + r""" + The call function to the pipeline for generation with HunyuanDiT. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds_2` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds_2` is passed directly. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback_on_step_end (`Callable[[int, int], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A callback function or a list of callback functions to be called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`list[str]`, *optional*): + A list of tensor inputs that should be passed to the callback function. If not defined, all tensor + inputs will be passed. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise + Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). See Section 3.4 + original_size (`tuple[int, int]`, *optional*, defaults to `(1024, 1024)`): + The original size of the image. Used to calculate the time ids. + target_size (`tuple[int, int]`, *optional*): + The target size of the image. Used to calculate the time ids. + crops_coords_top_left (`tuple[int, int]`, *optional*, defaults to `(0, 0)`): + The top left coordinates of the crop. Used to calculate the time ids. + use_resolution_binning (`bool`, *optional*, defaults to `True`): + Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest + standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960, + 768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. default height and width + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + height = int((height // 16) * 16) + width = int((width // 16) * 16) + + if use_resolution_binning and (height, width) not in SUPPORTED_SHAPE: + width, height = map_to_standard_shapes(width, height) + height = int(height) + width = int(width) + logger.warning(f"Reshaped to (height, width)=({height}, {width}), Supported shapes are {SUPPORTED_SHAPE}") + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=77, + text_encoder_index=0, + ) + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds_2, + negative_prompt_embeds=negative_prompt_embeds_2, + prompt_attention_mask=prompt_attention_mask_2, + negative_prompt_attention_mask=negative_prompt_attention_mask_2, + max_sequence_length=256, + text_encoder_index=1, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7 create image_rotary_emb, style embedding & time ids + grid_height = height // 8 // self.transformer.config.patch_size + grid_width = width // 8 // self.transformer.config.patch_size + base_size = 512 // 8 // self.transformer.config.patch_size + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) + image_rotary_emb = get_2d_rotary_pos_embed( + self.transformer.inner_dim // self.transformer.num_heads, + grid_crops_coords, + (grid_height, grid_width), + device=device, + output_type="pt", + ) + + style = torch.tensor([0], device=device) + + target_size = target_size or (height, width) + add_time_ids = list(original_size + target_size + crops_coords_top_left) + add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) + prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) + add_time_ids = torch.cat([add_time_ids] * 2, dim=0) + style = torch.cat([style] * 2, dim=0) + + prompt_embeds = prompt_embeds.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + prompt_embeds_2 = prompt_embeds_2.to(device=device) + prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device) + add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat( + batch_size * num_images_per_prompt, 1 + ) + style = style.to(device=device).repeat(batch_size * num_images_per_prompt) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( + dtype=latent_model_input.dtype + ) + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + text_embedding_mask=prompt_attention_mask, + encoder_hidden_states_t5=prompt_embeds_2, + text_embedding_mask_t5=prompt_attention_mask_2, + image_meta_size=add_time_ids, + style=style, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + noise_pred, _ = noise_pred.chunk(2, dim=1) + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2) + negative_prompt_embeds_2 = callback_outputs.pop( + "negative_prompt_embeds_2", negative_prompt_embeds_2 + ) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/latent_diffusion/__init__.py b/diffusers/pipelines/latent_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..561f96fc71dc7b4404e09571e0b7eaa4ee02fde8 --- /dev/null +++ b/diffusers/pipelines/latent_diffusion/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_latent_diffusion"] = ["LDMBertModel", "LDMTextToImagePipeline"] + _import_structure["pipeline_latent_diffusion_superresolution"] = ["LDMSuperResolutionPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_latent_diffusion import LDMBertModel, LDMTextToImagePipeline + from .pipeline_latent_diffusion_superresolution import LDMSuperResolutionPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..ec43988f93892d0e719874f56dbae914354205ce --- /dev/null +++ b/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -0,0 +1,745 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import torch +import torch.nn as nn +from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutput +from transformers.utils import logging + +from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import is_torch_xla_available +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +class LDMTextToImagePipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using latent diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [`~transformers.BERT`]. + tokenizer ([`~transformers.BertTokenizer`]): + A `BertTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "bert->unet->vqvae" + + def __init__( + self, + vqvae: VQModel | AutoencoderKL, + bert: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + unet: UNet2DModel | UNet2DConditionModel, + scheduler: DDIMScheduler | PNDMScheduler | LMSDiscreteScheduler, + ): + super().__init__() + self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1) + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str], + height: int | None = None, + width: int | None = None, + num_inference_steps: int | None = 50, + guidance_scale: float | None = 1.0, + eta: float | None = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + **kwargs, + ) -> tuple | ImagePipelineOutput: + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 1.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. + + Example: + + ```py + >>> from diffusers import DiffusionPipeline + + >>> # load model and scheduler + >>> ldm = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") + + >>> # run pipeline in inference (sample random noise and denoise) + >>> prompt = "A painting of a squirrel eating a burger" + >>> images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images + + >>> # save images + >>> for idx, image in enumerate(images): + ... image.save(f"squirrel-{idx}.png") + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get unconditional embeddings for classifier free guidance + if guidance_scale != 1.0: + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=77, truncation=True, return_tensors="pt" + ) + negative_prompt_embeds = self.bert(uncond_input.input_ids.to(self._execution_device))[0] + + # get prompt text embeddings + text_input = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt") + prompt_embeds = self.bert(text_input.input_ids.to(self._execution_device))[0] + + # get the initial random noise unless the user supplied it + latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor( + latents_shape, generator=generator, device=self._execution_device, dtype=prompt_embeds.dtype + ) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self._execution_device) + + self.scheduler.set_timesteps(num_inference_steps) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + + extra_kwargs = {} + if accepts_eta: + extra_kwargs["eta"] = eta + + for t in self.progress_bar(self.scheduler.timesteps): + if guidance_scale == 1.0: + # guidance_scale of 1 means no guidance + latents_input = latents + context = prompt_embeds + else: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = torch.cat([latents] * 2) + context = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # predict the noise residual + noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample + # perform guidance + if guidance_scale != 1.0: + noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample + + if XLA_AVAILABLE: + xm.mark_step() + + # scale and decode the image latents with vae + latents = 1 / self.vqvae.config.scaling_factor * latents + image = self.vqvae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + + +################################################################################ +# Code for the text transformer model +################################################################################ +""" PyTorch LDMBERT model.""" + + +logger = logging.get_logger(__name__) + +LDMBERT_PRETRAINED_MODEL_ARCHIVE_list = [ + "ldm-bert", + # See all LDMBert models at https://huggingface.co/models?filter=ldmbert +] + + +LDMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "ldm-bert": "https://huggingface.co/valhalla/ldm-bert/blob/main/config.json", +} + + +""" LDMBERT model configuration""" + + +class LDMBertConfig(PretrainedConfig): + model_type = "ldmbert" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} + + def __init__( + self, + vocab_size=30522, + max_position_embeddings=77, + encoder_layers=32, + encoder_ffn_dim=5120, + encoder_attention_heads=8, + head_dim=64, + encoder_layerdrop=0.0, + activation_function="gelu", + d_model=1280, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + classifier_dropout=0.0, + scale_embedding=False, + use_cache=True, + pad_token_id=0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.head_dim = head_dim + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.classifier_dropout = classifier_dropout + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + + super().__init__(pad_token_id=pad_token_id, **kwargs) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->LDMBert +class LDMBertAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + head_dim: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = False, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = head_dim + self.inner_dim = head_dim * num_heads + + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.out_proj = nn.Linear(self.inner_dim, embed_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: torch.Tensor | None = None, + past_key_value: tuple[torch.Tensor] | None = None, + attention_mask: torch.Tensor | None = None, + layer_head_mask: torch.Tensor | None = None, + output_attentions: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class LDMBertEncoderLayer(nn.Module): + def __init__(self, config: LDMBertConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = LDMBertAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + head_dim=config.head_dim, + dropout=config.attention_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool | None = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Args: + hidden_states (`torch.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (`torch.Tensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.Tensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.bart.modeling_bart.BartPretrainedModel with Bart->LDMBert +class LDMBertPreTrainedModel(PreTrainedModel): + config_class = LDMBertConfig + base_model_prefix = "model" + _supports_gradient_checkpointing = True + _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"] + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + @property + def dummy_inputs(self): + pad_token = self.config.pad_token_id + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) + dummy_inputs = { + "attention_mask": input_ids.ne(pad_token), + "input_ids": input_ids, + } + return dummy_inputs + + +class LDMBertEncoder(LDMBertPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`LDMBertEncoderLayer`]. + + Args: + config: LDMBertConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: LDMBertConfig): + super().__init__(config) + + self.dropout = config.dropout + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim) + self.embed_positions = nn.Embedding(config.max_position_embeddings, embed_dim) + self.layers = nn.ModuleList([LDMBertEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layer_norm = nn.LayerNorm(embed_dim) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + head_mask: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | BaseModelOutput: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.BaseModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + seq_len = input_shape[1] + if position_ids is None: + position_ids = torch.arange(seq_len, dtype=torch.long, device=inputs_embeds.device).expand((1, -1)) + embed_pos = self.embed_positions(position_ids) + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if torch.is_grad_enabled() and self.gradient_checkpointing: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer, + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class LDMBertModel(LDMBertPreTrainedModel): + _no_split_modules = [] + + def __init__(self, config: LDMBertConfig): + super().__init__(config) + self.model = LDMBertEncoder(config) + self.to_logits = nn.Linear(config.hidden_size, config.vocab_size) + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + return outputs diff --git a/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py b/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py new file mode 100644 index 0000000000000000000000000000000000000000..18cb8274f9b5fefd481346c25297d069e60984aa --- /dev/null +++ b/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py @@ -0,0 +1,196 @@ +import inspect + +import numpy as np +import PIL.Image +import torch + +from ...models import UNet2DModel, VQModel +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import PIL_INTERPOLATION, is_torch_xla_available +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +def preprocess(image): + w, h = image.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +class LDMSuperResolutionPipeline(DiffusionPipeline): + r""" + A pipeline for image super-resolution using latent diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) model to encode and decode images to and from latent representations. + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], + [`EulerAncestralDiscreteScheduler`], [`DPMSolverMultistepScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vqvae: VQModel, + unet: UNet2DModel, + scheduler: DDIMScheduler + | PNDMScheduler + | LMSDiscreteScheduler + | EulerDiscreteScheduler + | EulerAncestralDiscreteScheduler + | DPMSolverMultistepScheduler, + ): + super().__init__() + self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + image: torch.Tensor | PIL.Image.Image = None, + batch_size: int | None = 1, + num_inference_steps: int | None = 100, + eta: float | None = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + ) -> tuple | ImagePipelineOutput: + r""" + The call function to the pipeline for generation. + + Args: + image (`torch.Tensor` or `PIL.Image.Image`): + `Image` or tensor representing an image batch to be used as the starting point for the process. + batch_size (`int`, *optional*, defaults to 1): + Number of images to generate. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. + + Example: + + ```py + >>> import requests + >>> from PIL import Image + >>> from io import BytesIO + >>> from diffusers import LDMSuperResolutionPipeline + >>> import torch + + >>> # load model and scheduler + >>> pipeline = LDMSuperResolutionPipeline.from_pretrained("CompVis/ldm-super-resolution-4x-openimages") + >>> pipeline = pipeline.to("cuda") + + >>> # let's download an image + >>> url = ( + ... "https://user-images.githubusercontent.com/38061659/199705896-b48e17b8-b231-47cd-a270-4ffa5a93fa3e.png" + ... ) + >>> response = requests.get(url) + >>> low_res_img = Image.open(BytesIO(response.content)).convert("RGB") + >>> low_res_img = low_res_img.resize((128, 128)) + + >>> # run pipeline in inference (sample random noise and denoise) + >>> upscaled_image = pipeline(low_res_img, num_inference_steps=100, eta=1).images[0] + >>> # save image + >>> upscaled_image.save("ldm_generated_image.png") + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, torch.Tensor): + batch_size = image.shape[0] + else: + raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}") + + if isinstance(image, PIL.Image.Image): + image = preprocess(image) + + height, width = image.shape[-2:] + + # in_channels should be 6: 3 for latents, 3 for low resolution image + latents_shape = (batch_size, self.unet.config.in_channels // 2, height, width) + latents_dtype = next(self.unet.parameters()).dtype + + latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) + + image = image.to(device=self.device, dtype=latents_dtype) + + # set timesteps and move to the correct device + self.scheduler.set_timesteps(num_inference_steps, device=self.device) + timesteps_tensor = self.scheduler.timesteps + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature. + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_kwargs = {} + if accepts_eta: + extra_kwargs["eta"] = eta + + for t in self.progress_bar(timesteps_tensor): + # concat latents and low resolution image in the channel dimension. + latents_input = torch.cat([latents, image], dim=1) + latents_input = self.scheduler.scale_model_input(latents_input, t) + # predict the noise residual + noise_pred = self.unet(latents_input, t).sample + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample + + if XLA_AVAILABLE: + xm.mark_step() + + # decode the image latents with the VQVAE + image = self.vqvae.decode(latents).sample + image = torch.clamp(image, -1.0, 1.0) + image = image / 2 + 0.5 + image = image.cpu().permute(0, 2, 3, 1).numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/latte/__init__.py b/diffusers/pipelines/latte/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4296b42e125303c0e036b9f2deecc36bdf959de3 --- /dev/null +++ b/diffusers/pipelines/latte/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_latte"] = ["LattePipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_latte import LattePipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/latte/pipeline_latte.py b/diffusers/pipelines/latte/pipeline_latte.py new file mode 100644 index 0000000000000000000000000000000000000000..eed7762cebf15a9dec28e709b27399660ff7d517 --- /dev/null +++ b/diffusers/pipelines/latte/pipeline_latte.py @@ -0,0 +1,914 @@ +# Copyright 2025 the Latte Team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +from dataclasses import dataclass +from typing import Callable + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...models import AutoencoderKL, LatteTransformer3DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + BACKENDS_MAPPING, + BaseOutput, + deprecate, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...video_processor import VideoProcessor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LattePipeline + >>> from diffusers.utils import export_to_gif + + >>> # You can replace the checkpoint id with "maxin-cn/Latte-1" too. + >>> pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16) + >>> # Enable memory optimizations. + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A small cactus with a happy face in the Sahara desert." + >>> videos = pipe(prompt).frames[0] + >>> export_to_gif(videos, "latte.gif") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class LattePipelineOutput(BaseOutput): + frames: torch.Tensor + + +class LattePipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Latte. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. Latte uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`LatteTransformer3DModel`]): + A text conditioned `LatteTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + bad_punct_regex = re.compile(r"[#®•©™&@·º½¾¿¡§~\)\(\]\[\}\{\|\\/\\*]{1,}") + + _optional_components = ["tokenizer", "text_encoder"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: LatteTransformer3DModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) + + # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py + def mask_text_embeddings(self, emb, mask): + if emb.shape[0] == 1: + keep_index = mask.sum().item() + return emb[:, :, :keep_index, :], keep_index # 1, 120, 4096 -> 1 7 4096 + else: + masked_feature = emb * mask[:, None, :, None] # 1 120 4096 + return masked_feature, emb.shape[2] + + # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + clean_caption: bool = False, + mask_feature: bool = True, + dtype=None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + Latte, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of video that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For Latte, it's should be the embeddings of the "" string. + clean_caption (bool, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + mask_feature: (bool, defaults to `True`): + If `True`, the function will mask the text embeddings. + """ + embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + max_length = 120 + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds_attention_mask = attention_mask + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds_attention_mask = torch.ones_like(prompt_embeds) + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed, -1) + prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + # Perform additional masking. + if mask_feature and not embeds_initially_provided: + prompt_embeds = prompt_embeds.unsqueeze(1) + masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask) + masked_prompt_embeds = masked_prompt_embeds.squeeze(1) + masked_negative_prompt_embeds = ( + negative_prompt_embeds[:, :keep_indices, :] if negative_prompt_embeds is not None else None + ) + + return masked_prompt_embeds, masked_negative_prompt_embeds + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str = "", + num_inference_steps: int = 50, + timesteps: list[int] | None = None, + guidance_scale: float = 7.5, + num_images_per_prompt: int = 1, + video_length: int = 16, + height: int = 512, + width: int = 512, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + output_type: str = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + clean_caption: bool = True, + mask_feature: bool = True, + enable_temporal_attentions: bool = True, + decode_chunk_size: int = 14, + ) -> LattePipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked to + the text `prompt`, usually at the expense of lower video quality. + video_length (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated video. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For Latte this negative prompt should be "". If not provided, + negative_prompt_embeds will be generated from `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate video. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable[[int, int], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A callback function or a list of callback functions to be called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`list[str]`, *optional*): + A list of tensor inputs that should be passed to the callback function. If not defined, all tensor + inputs will be passed. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked. + enable_temporal_attentions (`bool`, *optional*, defaults to `True`): Whether to enable temporal attentions + decode_chunk_size (`int`, *optional*): + The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the + expense of more memory usage. By default, the decoder decodes all frames at once for maximal quality. + For lower memory usage, reduce `decode_chunk_size`. + + Examples: + + Returns: + [`~pipelines.latte.pipeline_latte.LattePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.latte.pipeline_latte.LattePipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images + """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default + decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else video_length + + # 1. Check inputs. Raise error if not correct + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._current_timestep = None + self._interrupt = False + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + mask_feature=mask_feature, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + video_length, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + is_npu = latent_model_input.device.type == "npu" + if isinstance(current_timestep, float): + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + else: + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=current_timestep, + enable_temporal_attentions=enable_temporal_attentions, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # use learned sigma? + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred = noise_pred.chunk(2, dim=1)[0] + + # compute previous video: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latents": + deprecation_message = ( + "Passing `output_type='latents'` is deprecated. Please pass `output_type='latent'` instead." + ) + deprecate("output_type_latents", "1.0.0", deprecation_message, standard_warn=False) + output_type = "latent" + + if not output_type == "latent": + video = self.decode_latents(latents, video_length, decode_chunk_size=decode_chunk_size) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return LattePipelineOutput(frames=video) + + # Similar to diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion.decode_latents + def decode_latents(self, latents: torch.Tensor, video_length: int, decode_chunk_size: int = 14): + # [batch, channels, frames, height, width] -> [batch*frames, channels, height, width] + latents = latents.permute(0, 2, 1, 3, 4).flatten(0, 1) + + latents = 1 / self.vae.config.scaling_factor * latents + + forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward + accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys()) + + # decode decode_chunk_size frames at a time to avoid OOM + frames = [] + for i in range(0, latents.shape[0], decode_chunk_size): + num_frames_in = latents[i : i + decode_chunk_size].shape[0] + decode_kwargs = {} + if accepts_num_frames: + # we only pass num_frames_in if it's expected + decode_kwargs["num_frames"] = num_frames_in + + frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample + frames.append(frame) + frames = torch.cat(frames, dim=0) + + # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width] + frames = frames.reshape(-1, video_length, *frames.shape[1:]).permute(0, 2, 1, 3, 4) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.float() + return frames diff --git a/diffusers/pipelines/longcat_image/pipeline_longcat_image.py b/diffusers/pipelines/longcat_image/pipeline_longcat_image.py new file mode 100644 index 0000000000000000000000000000000000000000..19720d7bbab80e85ab1cc08e34ecdc841bd66263 --- /dev/null +++ b/diffusers/pipelines/longcat_image/pipeline_longcat_image.py @@ -0,0 +1,666 @@ +# Copyright 2025 MeiTuan LongCat-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import re +from typing import Any + +import numpy as np +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import LongCatImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import LongCatImagePipelineOutput +from .system_messages import SYSTEM_PROMPT_EN, SYSTEM_PROMPT_ZH + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LongCatImagePipeline + + >>> pipe = LongCatImagePipeline.from_pretrained("meituan-longcat/LongCat-Image", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "一个年轻的亚裔女性,身穿黄色针织衫,搭配白色项链。她的双手放在膝盖上,表情恬静。背景是一堵粗糙的砖墙,午后的阳光温暖地洒在她身上,营造出一种宁静而温馨的氛围。镜头采用中距离视角,突出她的神态和服饰的细节。光线柔和地打在她的脸上,强调她的五官和饰品的质感,增加画面的层次感与亲和力。整个画面构图简洁,砖墙的纹理与阳光的光影效果相得益彰,突显出人物的优雅与从容。" + >>> image = pipe( + ... prompt, + ... height=768, + ... width=1344, + ... num_inference_steps=50, + ... guidance_scale=4.5, + ... generator=torch.Generator("cpu").manual_seed(43), + ... enable_cfg_renorm=True, + ... ).images[0] + >>> image.save("longcat_image.png") + ``` +""" + + +def get_prompt_language(prompt): + pattern = re.compile(r"[\u4e00-\u9fff]") + if bool(pattern.search(prompt)): + return "zh" + return "en" + + +def split_quotation(prompt, quote_pairs=None): + """ + Implement a regex-based string splitting algorithm that identifies delimiters defined by single or double quote + pairs. Examples:: + >>> prompt_en = "Please write 'Hello' on the blackboard for me." >>> print(split_quotation(prompt_en)) >>> # + output: [('Please write ', False), ("'Hello'", True), (' on the blackboard for me.', False)] + """ + word_internal_quote_pattern = re.compile(r"[a-zA-Z]+'[a-zA-Z]+") + matches_word_internal_quote_pattern = word_internal_quote_pattern.findall(prompt) + mapping_word_internal_quote = [] + + for i, word_src in enumerate(set(matches_word_internal_quote_pattern)): + word_tgt = "longcat_$##$_longcat" * (i + 1) + prompt = prompt.replace(word_src, word_tgt) + mapping_word_internal_quote.append([word_src, word_tgt]) + + if quote_pairs is None: + quote_pairs = [("'", "'"), ('"', '"'), ("‘", "’"), ("“", "”")] + pattern = "|".join([re.escape(q1) + r"[^" + re.escape(q1 + q2) + r"]*?" + re.escape(q2) for q1, q2 in quote_pairs]) + parts = re.split(f"({pattern})", prompt) + + result = [] + for part in parts: + for word_src, word_tgt in mapping_word_internal_quote: + part = part.replace(word_tgt, word_src) + if re.match(pattern, part): + if len(part): + result.append((part, True)) + else: + if len(part): + result.append((part, False)) + return result + + +def prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=None, height=None, width=None): + if type == "text": + assert num_token + if height or width: + print('Warning: The parameters of height and width will be ignored in "text" type.') + pos_ids = torch.zeros(num_token, 3) + pos_ids[..., 0] = modality_id + pos_ids[..., 1] = torch.arange(num_token) + start[0] + pos_ids[..., 2] = torch.arange(num_token) + start[1] + elif type == "image": + assert height and width + if num_token: + print('Warning: The parameter of num_token will be ignored in "image" type.') + pos_ids = torch.zeros(height, width, 3) + pos_ids[..., 0] = modality_id + pos_ids[..., 1] = pos_ids[..., 1] + torch.arange(height)[:, None] + start[0] + pos_ids[..., 2] = pos_ids[..., 2] + torch.arange(width)[None, :] + start[1] + pos_ids = pos_ids.reshape(height * width, 3) + else: + raise KeyError(f'Unknow type {type}, only support "text" or "image".') + return pos_ids + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class LongCatImagePipeline(DiffusionPipeline, FromSingleFileMixin): + r""" + The pipeline for text-to-image generation. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + text_processor: Qwen2VLProcessor, + transformer: LongCatImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_processor=text_processor, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + self.prompt_template_encode_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n" + self.prompt_template_encode_suffix = "<|im_end|>\n<|im_start|>assistant\n" + self.default_sample_size = 128 + self.tokenizer_max_length = 512 + + def rewire_prompt(self, prompt, device): + prompt = [prompt] if isinstance(prompt, str) else prompt + all_text = [] + for each_prompt in prompt: + language = get_prompt_language(each_prompt) + if language == "zh": + question = SYSTEM_PROMPT_ZH + f"\n用户输入为:{each_prompt}\n改写后的prompt为:" + else: + question = SYSTEM_PROMPT_EN + f"\nUser Input: {each_prompt}\nRewritten prompt:" + message = [ + { + "role": "user", + "content": [ + {"type": "text", "text": question}, + ], + } + ] + # Preparation for inference + text = self.text_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True) + all_text.append(text) + + inputs = self.text_processor(text=all_text, padding=True, return_tensors="pt").to(self.text_encoder.device) + + generated_ids = self.text_encoder.generate(**inputs, max_new_tokens=self.tokenizer_max_length) + generated_ids.to(device) + generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] + output_text = self.text_processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + rewrite_prompt = output_text + return rewrite_prompt + + def _encode_prompt(self, prompt: list[str]): + batch_all_tokens = [] + + for each_prompt in prompt: + all_tokens = [] + for clean_prompt_sub, matched in split_quotation(each_prompt): + if matched: + for sub_word in clean_prompt_sub: + tokens = self.tokenizer(sub_word, add_special_tokens=False)["input_ids"] + all_tokens.extend(tokens) + else: + tokens = self.tokenizer(clean_prompt_sub, add_special_tokens=False)["input_ids"] + all_tokens.extend(tokens) + + if len(all_tokens) > self.tokenizer_max_length: + logger.warning( + "Your input was truncated because `max_sequence_length` is set to " + f" {self.tokenizer_max_length} input token nums : {len(all_tokens)}" + ) + all_tokens = all_tokens[: self.tokenizer_max_length] + batch_all_tokens.append(all_tokens) + + text_tokens_and_mask = self.tokenizer.pad( + {"input_ids": batch_all_tokens}, + max_length=self.tokenizer_max_length, + padding="max_length", + return_attention_mask=True, + return_tensors="pt", + ) + + prefix_tokens = self.tokenizer(self.prompt_template_encode_prefix, add_special_tokens=False)["input_ids"] + suffix_tokens = self.tokenizer(self.prompt_template_encode_suffix, add_special_tokens=False)["input_ids"] + prefix_len = len(prefix_tokens) + suffix_len = len(suffix_tokens) + + prefix_tokens_mask = torch.tensor([1] * len(prefix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype) + suffix_tokens_mask = torch.tensor([1] * len(suffix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype) + + prefix_tokens = torch.tensor(prefix_tokens, dtype=text_tokens_and_mask.input_ids.dtype) + suffix_tokens = torch.tensor(suffix_tokens, dtype=text_tokens_and_mask.input_ids.dtype) + + batch_size = text_tokens_and_mask.input_ids.size(0) + + prefix_tokens_batch = prefix_tokens.unsqueeze(0).expand(batch_size, -1) + suffix_tokens_batch = suffix_tokens.unsqueeze(0).expand(batch_size, -1) + prefix_mask_batch = prefix_tokens_mask.unsqueeze(0).expand(batch_size, -1) + suffix_mask_batch = suffix_tokens_mask.unsqueeze(0).expand(batch_size, -1) + + input_ids = torch.cat((prefix_tokens_batch, text_tokens_and_mask.input_ids, suffix_tokens_batch), dim=-1) + attention_mask = torch.cat((prefix_mask_batch, text_tokens_and_mask.attention_mask, suffix_mask_batch), dim=-1) + + input_ids = input_ids.to(self.device) + attention_mask = attention_mask.to(self.device) + + text_output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True) + # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] + # clone to have a contiguous tensor + prompt_embeds = text_output.hidden_states[-1].detach() + prompt_embeds = prompt_embeds[:, prefix_len:-suffix_len, :] + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int | None = 1, + prompt_embeds: torch.Tensor | None = None, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is None: + prompt_embeds = self._encode_prompt(prompt) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=prompt_embeds.shape[1]).to( + self.device + ) + return prompt_embeds.to(self.device), text_ids + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = prepare_pos_ids( + modality_id=1, + type="image", + start=(self.tokenizer_max_length, self.tokenizer_max_length), + height=height // 2, + width=width // 2, + ).to(device) + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device) + latents = latents.to(dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + def check_inputs( + self, prompt, height, width, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + @replace_example_docstring(EXAMPLE_DOC_STRING) + @torch.no_grad() + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 4.5, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + enable_cfg_renorm: bool | None = True, + cfg_renorm_min: float | None = 0.0, + enable_prompt_rewrite: bool | None = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + enable_cfg_renorm: Whether to enable cfg_renorm. Enabling cfg_renorm will improve image quality, + but it may lead to a decrease in the stability of some image outputs.. + cfg_renorm_min: The minimum value of the cfg_renorm_scale range (0-1). + cfg_renorm_min = 1.0, renorm has no effect, while cfg_renorm_min=0.0, the renorm range is larger. + enable_prompt_rewrite: whether to enable prompt rewrite. + Examples: + + Returns: + [`~pipelines.LongCatImagePipelineOutput`] or `tuple`: [`~pipelines.LongCatImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + if enable_prompt_rewrite: + prompt = self.rewire_prompt(prompt, device) + logger.info(f"Rewrite prompt {prompt}!") + + negative_prompt = "" if negative_prompt is None else negative_prompt + (prompt_embeds, text_ids) = self.encode_prompt( + prompt=prompt, prompt_embeds=prompt_embeds, num_images_per_prompt=num_images_per_prompt + ) + if self.do_classifier_free_guidance: + (negative_prompt_embeds, negative_text_ids) = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + ) + + # 4. Prepare latent variables + num_channels_latents = 16 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + guidance = None + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + timestep = t.expand(latents.shape[0]).to(latents.dtype) + with self.transformer.cache_context("cond"): + noise_pred_text = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_pred_uncond = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if enable_cfg_renorm: + cond_norm = torch.norm(noise_pred_text, dim=-1, keepdim=True) + noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + scale = (cond_norm / (noise_norm + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) + noise_pred = noise_pred * scale + else: + noise_pred = noise_pred_text + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + if latents.dtype != self.vae.dtype: + latents = latents.to(dtype=self.vae.dtype) + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return LongCatImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py b/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py new file mode 100644 index 0000000000000000000000000000000000000000..69d5d82f18ecd94937cad4bca1f2b8dfc0d5c23e --- /dev/null +++ b/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py @@ -0,0 +1,727 @@ +# Copyright 2025 MeiTuan LongCat-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import math +import re +from typing import Any + +import numpy as np +import PIL +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import LongCatImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import LongCatImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from PIL import Image + >>> import torch + >>> from diffusers import LongCatImageEditPipeline + + >>> pipe = LongCatImageEditPipeline.from_pretrained( + ... "meituan-longcat/LongCat-Image-Edit", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> prompt = "change the cat to dog." + >>> input_image = Image.open("test.jpg").convert("RGB") + >>> image = pipe( + ... input_image, + ... prompt, + ... num_inference_steps=50, + ... guidance_scale=4.5, + ... generator=torch.Generator("cpu").manual_seed(43), + ... ).images[0] + >>> image.save("longcat_image_edit.png") + ``` +""" + + +# Copied from diffusers.pipelines.longcat_image.pipeline_longcat_image.split_quotation +def split_quotation(prompt, quote_pairs=None): + """ + Implement a regex-based string splitting algorithm that identifies delimiters defined by single or double quote + pairs. Examples:: + >>> prompt_en = "Please write 'Hello' on the blackboard for me." >>> print(split_quotation(prompt_en)) >>> # + output: [('Please write ', False), ("'Hello'", True), (' on the blackboard for me.', False)] + """ + word_internal_quote_pattern = re.compile(r"[a-zA-Z]+'[a-zA-Z]+") + matches_word_internal_quote_pattern = word_internal_quote_pattern.findall(prompt) + mapping_word_internal_quote = [] + + for i, word_src in enumerate(set(matches_word_internal_quote_pattern)): + word_tgt = "longcat_$##$_longcat" * (i + 1) + prompt = prompt.replace(word_src, word_tgt) + mapping_word_internal_quote.append([word_src, word_tgt]) + + if quote_pairs is None: + quote_pairs = [("'", "'"), ('"', '"'), ("‘", "’"), ("“", "”")] + pattern = "|".join([re.escape(q1) + r"[^" + re.escape(q1 + q2) + r"]*?" + re.escape(q2) for q1, q2 in quote_pairs]) + parts = re.split(f"({pattern})", prompt) + + result = [] + for part in parts: + for word_src, word_tgt in mapping_word_internal_quote: + part = part.replace(word_tgt, word_src) + if re.match(pattern, part): + if len(part): + result.append((part, True)) + else: + if len(part): + result.append((part, False)) + return result + + +# Copied from diffusers.pipelines.longcat_image.pipeline_longcat_image.prepare_pos_ids +def prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=None, height=None, width=None): + if type == "text": + assert num_token + if height or width: + print('Warning: The parameters of height and width will be ignored in "text" type.') + pos_ids = torch.zeros(num_token, 3) + pos_ids[..., 0] = modality_id + pos_ids[..., 1] = torch.arange(num_token) + start[0] + pos_ids[..., 2] = torch.arange(num_token) + start[1] + elif type == "image": + assert height and width + if num_token: + print('Warning: The parameter of num_token will be ignored in "image" type.') + pos_ids = torch.zeros(height, width, 3) + pos_ids[..., 0] = modality_id + pos_ids[..., 1] = pos_ids[..., 1] + torch.arange(height)[:, None] + start[0] + pos_ids[..., 2] = pos_ids[..., 2] + torch.arange(width)[None, :] + start[1] + pos_ids = pos_ids.reshape(height * width, 3) + else: + raise KeyError(f'Unknow type {type}, only support "text" or "image".') + return pos_ids + + +# Copied from diffusers.pipelines.longcat_image.pipeline_longcat_image.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def calculate_dimensions(target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = width if width % 16 == 0 else (width // 16 + 1) * 16 + height = height if height % 16 == 0 else (height // 16 + 1) * 16 + + width = int(width) + height = int(height) + + return width, height + + +class LongCatImageEditPipeline(DiffusionPipeline, FromSingleFileMixin): + r""" + The LongCat-Image-Edit pipeline for image editing. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + text_processor: Qwen2VLProcessor, + transformer: LongCatImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_processor=text_processor, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.image_processor_vl = text_processor.image_processor + + self.image_token = "<|image_pad|>" + self.prompt_template_encode_prefix = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>" + self.prompt_template_encode_suffix = "<|im_end|>\n<|im_start|>assistant\n" + self.default_sample_size = 128 + self.tokenizer_max_length = 512 + + def _encode_prompt(self, prompt, image): + raw_vl_input = self.image_processor_vl(images=image, return_tensors="pt") + pixel_values = raw_vl_input["pixel_values"] + image_grid_thw = raw_vl_input["image_grid_thw"] + all_tokens = [] + for clean_prompt_sub, matched in split_quotation(prompt[0]): + if matched: + for sub_word in clean_prompt_sub: + tokens = self.tokenizer(sub_word, add_special_tokens=False)["input_ids"] + all_tokens.extend(tokens) + else: + tokens = self.tokenizer(clean_prompt_sub, add_special_tokens=False)["input_ids"] + all_tokens.extend(tokens) + + if len(all_tokens) > self.tokenizer_max_length: + logger.warning( + "Your input was truncated because `max_sequence_length` is set to " + f" {self.tokenizer_max_length} input token nums : {len(len(all_tokens))}" + ) + all_tokens = all_tokens[: self.tokenizer_max_length] + + text_tokens_and_mask = self.tokenizer.pad( + {"input_ids": [all_tokens]}, + max_length=self.tokenizer_max_length, + padding="max_length", + return_attention_mask=True, + return_tensors="pt", + ) + + text = self.prompt_template_encode_prefix + + merge_length = self.image_processor_vl.merge_size**2 + while self.image_token in text: + num_image_tokens = image_grid_thw.prod() // merge_length + text = text.replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) + text = text.replace("<|placeholder|>", self.image_token) + + prefix_tokens = self.tokenizer(text, add_special_tokens=False)["input_ids"] + suffix_tokens = self.tokenizer(self.prompt_template_encode_suffix, add_special_tokens=False)["input_ids"] + + vision_start_token_id = self.tokenizer.convert_tokens_to_ids("<|vision_start|>") + prefix_len = prefix_tokens.index(vision_start_token_id) + suffix_len = len(suffix_tokens) + + prefix_tokens_mask = torch.tensor([1] * len(prefix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype) + suffix_tokens_mask = torch.tensor([1] * len(suffix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype) + + prefix_tokens = torch.tensor(prefix_tokens, dtype=text_tokens_and_mask.input_ids.dtype) + suffix_tokens = torch.tensor(suffix_tokens, dtype=text_tokens_and_mask.input_ids.dtype) + + input_ids = torch.cat((prefix_tokens, text_tokens_and_mask.input_ids[0], suffix_tokens), dim=-1) + attention_mask = torch.cat( + (prefix_tokens_mask, text_tokens_and_mask.attention_mask[0], suffix_tokens_mask), dim=-1 + ) + + input_ids = input_ids.unsqueeze(0).to(self.device) + attention_mask = attention_mask.unsqueeze(0).to(self.device) + + pixel_values = pixel_values.to(self.device) + image_grid_thw = image_grid_thw.to(self.device) + + text_output = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + output_hidden_states=True, + ) + # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] + # clone to have a contiguous tensor + prompt_embeds = text_output.hidden_states[-1].detach() + prompt_embeds = prompt_embeds[:, prefix_len:-suffix_len, :] + return prompt_embeds + + def encode_prompt( + self, + prompt: list[str] = None, + image: torch.Tensor | None = None, + num_images_per_prompt: int | None = 1, + prompt_embeds: torch.Tensor | None = None, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is None: + prompt_embeds = self._encode_prompt(prompt, image) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=prompt_embeds.shape[1]).to( + self.device + ) + return prompt_embeds, text_ids + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + def prepare_latents( + self, + image, + batch_size, + num_channels_latents, + height, + width, + dtype, + prompt_embeds_length, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + image_latents, image_latents_ids = None, None + + if image is not None: + image = image.to(device=self.device, dtype=dtype) + + if image.shape[1] != self.vae.config.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + + image_latents_ids = prepare_pos_ids( + modality_id=2, + type="image", + start=(prompt_embeds_length, prompt_embeds_length), + height=height // 2, + width=width // 2, + ).to(device, dtype=torch.float64) + + shape = (batch_size, num_channels_latents, height, width) + latents_ids = prepare_pos_ids( + modality_id=1, + type="image", + start=(prompt_embeds_length, prompt_embeds_length), + height=height // 2, + width=width // 2, + ).to(device) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents, image_latents, latents_ids, image_latents_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + def check_inputs( + self, prompt, height, width, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None: + if isinstance(prompt, str): + pass + elif isinstance(prompt, list) and len(prompt) == 1: + pass + else: + raise ValueError( + f"`prompt` must be a `str` or a `list` of length 1, but is {prompt} (type: {type(prompt)})" + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + @replace_example_docstring(EXAMPLE_DOC_STRING) + @torch.no_grad() + def __call__( + self, + image: PIL.Image.Image | None = None, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 4.5, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Examples: + + Returns: + [`~pipelines.LongCatImagePipelineOutput`] or `tuple`: [`~pipelines.LongCatImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + image_size = image[0].size if isinstance(image, list) else image.size + calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] * 1.0 / image_size[1]) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + calculated_height, + calculated_width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + image = self.image_processor.resize(image, calculated_height, calculated_width) + prompt_image = self.image_processor.resize(image, calculated_height // 2, calculated_width // 2) + image = self.image_processor.preprocess(image, calculated_height, calculated_width) + + negative_prompt = "" if negative_prompt is None else negative_prompt + (prompt_embeds, text_ids) = self.encode_prompt( + prompt=prompt, image=prompt_image, prompt_embeds=prompt_embeds, num_images_per_prompt=num_images_per_prompt + ) + if self.do_classifier_free_guidance: + (negative_prompt_embeds, negative_text_ids) = self.encode_prompt( + prompt=negative_prompt, + image=prompt_image, + prompt_embeds=negative_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + ) + + # 4. Prepare latent variables + num_channels_latents = 16 + latents, image_latents, latents_ids, image_latents_ids = self.prepare_latents( + image, + batch_size * num_images_per_prompt, + num_channels_latents, + calculated_height, + calculated_width, + prompt_embeds.dtype, + prompt_embeds.shape[1], + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + guidance = None + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + if image is not None: + latent_image_ids = torch.cat([latents_ids, image_latents_ids], dim=0) + else: + latent_image_ids = latents_ids + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + + timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + with self.transformer.cache_context("cond"): + noise_pred_text = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + noise_pred_text = noise_pred_text[:, :image_seq_len] + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + noise_pred_uncond = noise_pred_uncond[:, :image_seq_len] + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred = noise_pred_text + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, calculated_height, calculated_width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + if latents.dtype != self.vae.dtype: + latents = latents.to(dtype=self.vae.dtype) + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return LongCatImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/longcat_image/system_messages.py b/diffusers/pipelines/longcat_image/system_messages.py new file mode 100644 index 0000000000000000000000000000000000000000..b8b2318e4e813970ebba7e5da4a07c26f238f4e8 --- /dev/null +++ b/diffusers/pipelines/longcat_image/system_messages.py @@ -0,0 +1,142 @@ +SYSTEM_PROMPT_EN = """ +You are a prompt engineering expert for text-to-image models. Since text-to-image models have limited capabilities in +understanding user prompts, you need to identify the core theme and intent of the user's input and improve the model's +understanding accuracy and generation quality through optimization and rewriting. The rewrite must strictly retain all +information from the user's original prompt without deleting or distorting any details. Specific requirements are as +follows: +1. The rewrite must not affect any information expressed in the user's original prompt; the rewritten prompt should use + coherent natural language, avoid low-information redundant descriptions, and keep the rewritten prompt length as + concise as possible. +2. Ensure consistency between input and output languages: Chinese input yields Chinese output, and English input yields + English output. The rewritten token count should not exceed 512. +3. The rewritten description should further refine subject characteristics and aesthetic techniques appearing in the + original prompt, such as lighting and textures. +4. If the original prompt does not specify an image style, ensure the rewritten prompt uses a **realistic photography + style**. If the user specifies a style, retain the user's style. +5. When the original prompt requires reasoning to clarify user intent, use logical reasoning based on world knowledge + to convert vague abstract descriptions into specific tangible objects (e.g., convert "the tallest animal" to "a + giraffe"). +6. When the original prompt requires text generation, please use double quotes to enclose the text part (e.g., `"50% + OFF"`). +7. When the original prompt requires generating text-heavy scenes like webpages, logos, UIs, or posters, and no + specific text content is specified, you need to infer appropriate text content and enclose it in double quotes. For + example, if the user inputs: "A tourism flyer with a grassland theme," it should be rewritten as: "A tourism flyer + with the image title 'Grassland'." +8. When negative words exist in the original prompt, ensure the rewritten prompt does not contain negative words. For + example, "a lakeside without boats" should be rewritten such that the word "boat" does not appear at all. +9. Except for text content explicitly requested by the user, **adding any extra text content is prohibited**. +Here are examples of rewrites for different types of prompts: # Examples (Few-Shot Learning) + 1. User Input: An animal with nine lives. + Rewrite Output: A cat bathed in soft sunlight, its fur soft and glossy. The background is a comfortable home + environment with light from the window filtering through curtains, creating a warm light and shadow effect. The + shot uses a medium distance perspective to highlight the cat's leisurely and stretched posture. Light cleverly hits + the cat's face, emphasizing its spirited eyes and delicate whiskers, adding depth and affinity to the image. + 2. User Input: Create an anime-style tourism flyer with a grassland theme. + Rewrite Output: In the lower right of the center, a short-haired girl sits sideways on a gray, irregularly shaped + rock. She wears a white short-sleeved dress and brown flat shoes, holding a bunch of small white flowers in her + left hand, smiling with her legs hanging naturally. The girl has dark brown shoulder-length hair with bangs + covering her forehead, brown eyes, and a slightly open mouth. The rock surface has textures of varying depths. To + the girl's left and front is lush grass, with long, yellow-green blades, some glowing golden in the sunlight. The + grass extends into the distance, forming rolling green hills that fade in color as they recede. The sky occupies + the upper half of the picture, pale blue dotted with a few fluffy white clouds. In the upper left corner, there is + a line of text in italic, dark green font reading "Explore Nature's Peace". Colors are dominated by green, blue, + and yellow, fluid lines, and distinct light and shadow contrast, creating a quiet and comfortable atmosphere. + 3. User Input: A Christmas sale poster with a red background, promoting a Buy 1 Get 1 Free milk tea offer. + Rewrite Output: The poster features an overall red tone, embellished with white snowflake patterns on the top and + left side. The upper right features a bunch of holly leaves with red berries and a pine cone. In the upper center, + golden 3D text reads "Christmas Heartwarming Feedback" centered, along with red bold text "Buy 1 Get 1". Below, two + transparent cups filled with bubble tea are placed side by side; the tea is light brown with dark brown pearls + scattered at the bottom and middle. Below the cups, white snow piles up, decorated with pine branches, red berries, + and pine cones. A blurry Christmas tree is faintly visible in the lower right corner. The image has high clarity, + accurate text content, a unified design style, a prominent Christmas theme, and a reasonable layout, providing + strong visual appeal. + 4. User Input: A woman indoors shot in natural light, smiling with arms crossed, showing a relaxed and confident + posture. + Rewrite Output: The image features a young Asian woman with long dark brown hair naturally falling over her + shoulders, with some strands illuminated by light, showing a soft sheen. Her features are delicate, with long + eyebrows, bright and spirited dark brown eyes looking directly at the camera, revealing peace and confidence. She + has a high nose bridge, full lips with nude lipstick, and corners of the mouth slightly raised in a faint smile. + Her skin is fair, with cheeks and collarbones illuminated by warm light, showing a healthy ruddiness. She wears a + black spaghetti strap tank top revealing graceful collarbone lines, and a thin gold necklace with small beads and + metal bars glinting in the light. Her outer layer is a beige knitted cardigan, soft in texture with visible + knitting patterns on the sleeves. Her arms are crossed over her chest, hands covered by the cardigan sleeves, in a + relaxed posture. The background is a pure dark brown without extra decoration, making the figure the absolute + focus. The figure is located in the center of the frame. Light enters from the upper right, creating bright spots + on her left cheek, neck, and collarbone, while the right side is slightly shadowed, creating a three-dimensional + and soft tone. Image details are clear, showcasing skin texture, hair, and clothing materials well. Colors are + dominated by warm tones, with the combination of beige and dark brown creating a warm and comfortable atmosphere. + The overall style is natural, elegant, and artistic. + 5. User Input: Create a series of images showing the growth process of an apple from seed to fruit. The series should + include four stages: 1. Sowing, 2. Seedling growth, 3. Plant maturity, 4. Fruit harvesting. + Rewrite Output: A 4-panel exquisite illustration depicting the growth process of an apple, capturing each stage + precisely and clearly. 1. "Sowing": A close-up shot of a hand gently placing a small apple seed into fertile dark + soil, with visible soil texture and the seed's smooth surface. The background is a soft-focus garden dotted with + green leaves and sunlight filtering through. 2. "Seedling Growth": A young apple sapling breaks through the soil, + stretching tender green leaves toward the sky. The scene is set in a vibrant garden illuminated by warm golden + light, highlighting the seedling's delicate structure. 3. "Plant Maturity": A mature apple tree, lush with branches + and leaves, covered in tender green foliage and developing small apples. The background is a vibrant orchard under + a clear blue sky, with dappled sunlight creating a peaceful atmosphere. 4. "Fruit Harvesting": A hand reaches into + the tree to pick a ripe red apple, its smooth skin glistening in the sun. The scene shows the abundance of the + orchard, with baskets of apples in the background, giving a sense of fulfillment. Each illustration uses a + realistic style, focusing on details and harmonious colors to showcase the natural beauty and development of the + apple's life cycle. + 6. User Input: If 1 represents red, 2 represents green, 3 represents purple, and 4 represents yellow, please generate + a four-color rainbow based on this rule. The color order from top to bottom is 3142. + Rewrite Output: The image consists of four horizontally arranged colored stripes, ordered from top to bottom as + purple, red, yellow, and green. A white number is centered on each stripe. The top purple stripe features the + number "3", the red stripe below it has the number "1", the yellow stripe further down has the number "4", and the + bottom green stripe has the number "2". All numbers use a sans-serif font in pure white, forming a sharp contrast + with the background colors to ensure good readability. The stripes have high color saturation and a slight texture. + The overall layout is simple and clear, with distinct visual effects and no extra decorative elements, emphasizing + the numerical information. The image is high definition, with accurate colors and a consistent style, offering + strong visual appeal. + 7. User Input: A stone tablet carved with "Guan Guan Ju Jiu, On the River Isle", natural light, background is a + Chinese garden. + Rewrite Output: An ancient stone tablet carved with "Guan Guan Ju Jiu, On the River Isle", the surface covered with + traces of time, the writing clear and deep. Natural light falls from above, softly illuminating every detail of the + stone tablet and enhancing its sense of history. The background is an elegant Chinese garden featuring lush bamboo + forests, winding paths, and quiet pools, creating a serene and distant atmosphere. The overall picture uses a + realistic style with rich details and natural light and shadow effects, highlighting the cultural heritage of the + stone tablet and the classical beauty of the garden. +# Output Format Please directly output the rewritten and optimized Prompt content. Do not include any explanatory +language or JSON formatting, and do not add opening or closing quotes yourself.""" + + +SYSTEM_PROMPT_ZH = """ +你是一名文生图模型的prompt +engineering专家。由于文生图模型对用户prompt的理解能力有限,你需要识别用户输入的核心主题和意图,并通过优化改写提升模型的理解准确性和生成质量。改写必须严格保留用户原始prompt的所有信息,不得删减或曲解任何细节。 +具体要求如下: +1. 改写不能影响用户原始prompt里表达的任何信息,改写后的prompt应该使用连贯的自然语言表达,不要出现低信息量的冗余描述,尽可能保持改写后prompt长度精简。 +2. 请确保输入和输出的语言类型一致,中文输入中文输出,英文输入英文输出,改写后的token数量不要超过512个; +3. 改写后的描述应当进一步完善原始prompt中出现的主体特征、美学技巧,如打光、纹理等; +4. 如果原始prompt没有指定图片风格时,确保改写后的prompt使用真实摄影风格,如果用户指定了图片风格,则保留用户风格; +5. 当原始prompt需要推理才能明确用户意图时,根据世界知识进行适当逻辑推理,将模糊抽象描述转化为具体指向事物(例:将"最高的动物"转化为"一头长颈鹿")。 +6. 当原始prompt需要生成文字时,请使用双引号圈定文字部分,例:`"限时5折"`)。 +7. 当原始prompt需要生成网页、logo、ui、海报等文字场景时,且没有指定具体的文字内容时,需要推断出合适的文字内容,并使用双引号圈定,如用户输入:一个旅游宣传单,以草原为主题。应该改写成:一个旅游宣传单,图片标题为“草原”。 +8. 当原始prompt中存在否定词时,需要确保改写后的prompt不存在否定词,如没有船的湖边,改写后的prompt不能出现船这个词汇。 +9. 除非用户指定生成品牌logo,否则不要增加额外的品牌logo. +10. 除了用户明确要求书写的文字内容外,**禁止增加任何额外的文字内容**。 +以下是针对不同类型prompt改写的示例: + +# Examples (Few-Shot Learning) + 1. 用户输入: 九条命的动物。 + 改写输出: + 一只猫,被柔和的阳光笼罩着,毛发柔软而富有光泽。背景是一个舒适的家居环境,窗外的光线透过窗帘,形成温馨的光影效果。镜头采用中距离视角,突出猫悠闲舒展的姿态。光线巧妙地打在猫的脸部,强调它灵动的眼睛和精致的胡须,增加画面的层次感与亲和力。 + 2. 用户输入: 制作一个动画风格的旅游宣传单,以草原为主题。 + 改写输出: + 画面中央偏右下角,一个短发女孩侧身坐在灰色的不规则形状岩石上,她穿着白色短袖连衣裙和棕色平底鞋,左手拿着一束白色小花,面带微笑,双腿自然垂下。女孩的头发为深棕色,齐肩短发,刘海覆盖额头,眼睛呈棕色,嘴巴微张。岩石表面有深浅不一的纹理。女孩的左侧和前方是茂盛的草地,草叶细长,呈黄绿色,部分草叶在阳光下泛着金色的光芒,仿佛被阳光照亮。草地向远处延伸,形成连绵起伏的绿色山丘,山丘的颜色由近及远逐渐变浅。天空占据了画面的上半部分,呈淡蓝色,点缀着几朵白色蓬松的云彩。画面的左上角有一行文字,文字内容是斜体、深绿色的“Explore + Nature's Peace”。色彩以绿色、蓝色和黄色为主,线条流畅,光影明暗对比明显,营造出一种宁静、舒适的氛围。 + 3. 用户输入: 一张以红色为背景的圣诞节促销海报,主要宣传奶茶买一送一的优惠活动。 + 改写输出: 海报整体呈现红色调,上方和左侧点缀着白色雪花图案,右上方有一束冬青叶和红色浆果,以及一个松果。海报中央偏上位置,金色立体字样“圣诞节 + 暖心回馈”居中排列,和红色粗体字“买1送1”。海报下方,两个装满珍珠奶茶的透明杯子并排摆放,杯中奶茶呈浅棕色,底部和中间散布着深棕色珍珠。杯子下方,堆积着白色雪花,雪花上装饰着松枝、红色浆果和松果。右下角隐约可见一棵模糊的圣诞树。图片清晰度高,文字内容准确,整体设计风格统一,圣诞主题突出,排版布局合理,具有较强的视觉吸引力。 + 4. 用户输入: 一位女性在室内以自然光线拍摄,她面带微笑,双臂交叉,展现出轻松自信的姿态。 + 改写输出: + 画面中是一位年轻的亚洲女性,她拥有深棕色的长发,发丝自然地垂落在双肩,部分发丝被光线照亮,呈现出柔和的光泽。她的五官精致,眉毛修长,眼睛明亮有神,瞳孔呈深棕色,眼神直视镜头,流露出平和与自信。鼻梁挺拔,嘴唇丰满,涂有裸色系唇膏,嘴角微微上扬,展现出浅浅的微笑。她的肤色白皙,脸颊和锁骨处被暖色调的光线照亮,呈现出健康的红润感。她穿着一件黑色的细吊带背心,肩带纤细,露出优美的锁骨线条。脖颈上佩戴着一条金色的细项链,项链由小珠子和几个细长的金属条组成,在光线下闪烁着光泽。她的外搭是一件米黄色的针织开衫,材质柔软,袖子部分有明显的针织纹理。她双臂交叉在胸前,双手被开衫的袖子覆盖,姿态放松。背景是纯粹的深棕色,没有多余的装饰,使得人物成为画面的绝对焦点。人物位于画面中央。光线从画面的右上方射入,在人物的左侧脸颊、脖颈和锁骨处形成明亮的光斑,右侧则略显阴影,营造出立体感和柔和的影调。图像细节清晰,人物的皮肤纹理、发丝以及衣物材质都得到了很好的展现。色彩以暖色调为主,米黄色和深棕色的搭配营造出温馨舒适的氛围。整体呈现出一种自然、优雅且富有亲和力的艺术风格。 + 5. 用户输入:创作一系列图片,展现苹果从种子到结果的生长过程。该系列图片应包含以下四个阶段:1. 播种,2. 幼苗生长,3. 植物成熟,4. 果实采摘。 + 改写输出:一个4宫格的精美插图,描绘苹果的生长过程,精确清晰地捕捉每个阶段。1.“播种”:特写镜头,一只手轻轻地将一颗小小的苹果种子放入肥沃的深色土壤中,土壤的纹理和种子光滑的表面清晰可见。背景是花园的柔焦画面,点缀着绿色的树叶和透过树叶洒下的阳光。2.“幼苗生长”:一棵幼小的苹果树苗破土而出,嫩绿的叶子向天空舒展。场景设定在一个生机勃勃的花园中,温暖的金光照亮了它。幼苗的纤细结构。3.“植物的成熟”:一棵成熟的苹果树,枝繁叶茂,挂满了嫩绿的叶子和正在萌发的小苹果。背景是一片生机勃勃的果园,湛蓝的天空下,斑驳的阳光营造出宁静祥和的氛围。4.“采摘果实”:一只手伸向树上,摘下一个成熟的红苹果,苹果光滑的果皮在阳光下闪闪发光。画面展现了果园的丰收景象,背景中摆放着一篮篮的苹果,给人一种圆满满足的感觉。每幅插图都采用写实风格,注重细节,色彩和谐,展现了苹果生命周期的自然之美和发展过程。 + 6. 用户输入: 如果1代表红色,2代表绿色,3代表紫色,4代表黄色,请按照此规则生成四色彩虹。它的颜色顺序从上到下是3142 + 改写输出:图片由四个水平排列的彩色条纹组成,从上到下依次为紫色、红色、黄色和绿色。每个条纹上都居中放置一个白色数字。最上方的紫色条纹上是数字“3”,其下方红色条纹上是数字“1”,再下方黄色条纹上是数字“4”,最下方的绿色条纹上是数字“2”。所有数字均采用无衬线字体,颜色为纯白色,与背景色形成鲜明对比,确保了良好的可读性。条纹的颜色饱和度高,且带有轻微的纹理感,整体排版简洁明了,视觉效果清晰,没有多余的装饰元素,强调了数字信息本身。图片整体清晰度高,色彩准确,风格一致,具有较强的视觉吸引力。 + 7. 用户输入:石碑上刻着“关关雎鸠,在河之洲”,自然光照,背景是中式园林 + 改写输出:一块古老的石碑上刻着“关关雎鸠,在河之洲”,石碑表面布满岁月的痕迹,字迹清晰而深刻。自然光线从上方洒下,柔和地照亮石碑的每一个细节,增强了其历史感。背景是一座典雅的中式园林,园林中有翠绿的竹林、蜿蜒的小径和静谧的水池,营造出一种宁静而悠远的氛围。整体画面采用写实风格,细节丰富,光影效果自然,突出了石碑的文化底蕴和园林的古典美。 +# 输出格式 请直接输出改写优化后的 Prompt 内容,不要包含任何解释性语言或 JSON 格式,不要自行添加开头或结尾的引号。 +""" diff --git a/diffusers/pipelines/lucy/__init__.py b/diffusers/pipelines/lucy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..580e1f37f30a4bfc91878692c9e0004a233aa64e --- /dev/null +++ b/diffusers/pipelines/lucy/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_lucy_edit"] = ["LucyEditPipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_lucy_edit import LucyEditPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/lucy/pipeline_lucy_edit.py b/diffusers/pipelines/lucy/pipeline_lucy_edit.py new file mode 100644 index 0000000000000000000000000000000000000000..392af492b702a53bff9dce509ce5fb31dee7c51a --- /dev/null +++ b/diffusers/pipelines/lucy/pipeline_lucy_edit.py @@ -0,0 +1,733 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# Copyright 2025 The Decart AI Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modifications by Decart AI Team: +# - Based on pipeline_wan.py, but with supports receiving a condition video appended to the channel dimension. + +import html +from typing import Any, Callable + +import regex as re +import torch +from PIL import Image +from transformers import AutoTokenizer, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import WanLoraLoaderMixin +from ...models import AutoencoderKLWan, WanTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import LucyPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> from typing import list + + >>> import torch + >>> from PIL import Image + + >>> from diffusers import AutoencoderKLWan, LucyEditPipeline + >>> from diffusers.utils import export_to_video, load_video + + >>> # Arguments + >>> url = "https://d2drjpuinn46lb.cloudfront.net/painter_original_edit.mp4" + >>> prompt = "Change the apron and blouse to a classic clown costume: satin polka-dot jumpsuit in bright primary colors, ruffled white collar, oversized pom-pom buttons, white gloves, oversized red shoes, red foam nose; soft window light from left, eye-level medium shot, natural folds and fabric highlights." + >>> negative_prompt = "" + >>> num_frames = 81 + >>> height = 480 + >>> width = 832 + + + >>> # Load video + >>> def convert_video(video: list[Image.Image]) -> list[Image.Image]: + ... video = load_video(url)[:num_frames] + ... video = [video[i].resize((width, height)) for i in range(num_frames)] + ... return video + + + >>> video = load_video(url, convert_method=convert_video) + + >>> # Load model + >>> model_id = "decart-ai/Lucy-Edit-Dev" + >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + >>> pipe = LucyEditPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> # Generate video + >>> output = pipe( + ... prompt=prompt, + ... video=video, + ... negative_prompt=negative_prompt, + ... height=480, + ... width=832, + ... num_frames=81, + ... guidance_scale=5.0, + ... ).frames[0] + + >>> # Export video + >>> export_to_video(output, "output.mp4", fps=24) + ``` +""" + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class LucyEditPipeline(DiffusionPipeline, WanLoraLoaderMixin): + r""" + Pipeline for video-to-video generation using Lucy Edit. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + transformer ([`WanTransformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + transformer_2 ([`WanTransformer3DModel`], *optional*): + Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables + two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise + stages. If not provided, only `transformer` is used. + boundary_ratio (`float`, *optional*, defaults to `None`): + Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising. + The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided, + `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps < + boundary_timestep. If `None`, only `transformer` is used for the entire denoising process. + """ + + model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _optional_components = ["transformer", "transformer_2"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKLWan, + scheduler: FlowMatchEulerDiscreteScheduler, + transformer: WanTransformer3DModel | None = None, + transformer_2: WanTransformer3DModel | None = None, + boundary_ratio: float | None = None, + expand_timesteps: bool = False, # Wan2.2 ti2v + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + transformer_2=transformer_2, + ) + self.register_to_config(boundary_ratio=boundary_ratio) + self.register_to_config(expand_timesteps=expand_timesteps) + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + video, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + guidance_scale_2=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if self.config.boundary_ratio is None and guidance_scale_2 is not None: + raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.") + + if video is None: + raise ValueError("`video` is required, received None.") + + def prepare_latents( + self, + video: torch.Tensor | None = None, + batch_size: int = 1, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + num_latent_frames = ( + (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1) + ) + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + # Prepare noise latents + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # Prepare condition latents + condition_latents = [ + retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="argmax") for vid in video + ] + + condition_latents = torch.cat(condition_latents, dim=0).to(dtype) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device, dtype + ) + + condition_latents = (condition_latents - latents_mean) * latents_std + + # Check shapes + assert latents.shape == condition_latents.shape, ( + f"Latents shape {latents.shape} does not match expected shape {condition_latents.shape}. Please check the input." + ) + + return latents, condition_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + video: list[Image.Image], + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + guidance_scale_2: float | None = None, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + video (`list[Image.Image]`): + The video to use as the condition for the video generation. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). + height (`int`, defaults to `480`): + The height in pixels of the generated image. + width (`int`, defaults to `832`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_scale_2 (`float`, *optional*, defaults to `None`): + Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's + `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2` + and the pipeline's `boundary_ratio` are not None. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`LucyPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + + Examples: + + Returns: + [`~LucyPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`LucyPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + video, + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + guidance_scale_2, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + if self.config.boundary_ratio is not None and guidance_scale_2 is None: + guidance_scale_2 = guidance_scale + + self._guidance_scale = guidance_scale + self._guidance_scale_2 = guidance_scale_2 + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = ( + self.transformer.config.out_channels + if self.transformer is not None + else self.transformer_2.config.out_channels + ) + video = self.video_processor.preprocess_video(video, height=height, width=width).to( + device, dtype=torch.float32 + ) + latents, condition_latents = self.prepare_latents( + video, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + mask = torch.ones(latents.shape, dtype=torch.float32, device=device) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + if self.config.boundary_ratio is not None: + boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps + else: + boundary_timestep = None + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + if boundary_timestep is None or t >= boundary_timestep: + # wan2.1 or high-noise stage in wan2.2 + current_model = self.transformer + current_guidance_scale = guidance_scale + else: + # low-noise stage in wan2.2 + current_model = self.transformer_2 + current_guidance_scale = guidance_scale_2 + + # latent_model_input = latents.to(transformer_dtype) + latent_model_input = torch.cat([latents, condition_latents], dim=1).to(transformer_dtype) + # latent_model_input = torch.cat([latents, latents], dim=1).to(transformer_dtype) + if self.config.expand_timesteps: + # seq_len: num_latent_frames * latent_height//2 * latent_width//2 + temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten() + # batch_size, seq_len + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + else: + timestep = t.expand(latents.shape[0]) + + with current_model.cache_context("cond"): + noise_pred = current_model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + with current_model.cache_context("uncond"): + noise_uncond = current_model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return LucyPipelineOutput(frames=video) diff --git a/diffusers/pipelines/lucy/pipeline_output.py b/diffusers/pipelines/lucy/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..197ce194f475e5b738bd19cb6201fe943ad603ea --- /dev/null +++ b/diffusers/pipelines/lucy/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class LucyPipelineOutput(BaseOutput): + r""" + Output class for Lucy pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/diffusers/pipelines/lumina/__init__.py b/diffusers/pipelines/lumina/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a19dc7e94641f59cad8588e867412829cd4fa793 --- /dev/null +++ b/diffusers/pipelines/lumina/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_lumina import LuminaPipeline, LuminaText2ImgPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/lumina/pipeline_lumina.py b/diffusers/pipelines/lumina/pipeline_lumina.py new file mode 100644 index 0000000000000000000000000000000000000000..cc123218f4ee5f7452c7dcb77b7de41865101c4d --- /dev/null +++ b/diffusers/pipelines/lumina/pipeline_lumina.py @@ -0,0 +1,957 @@ +# Copyright 2025 Alpha-VLLM and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import math +import re +import urllib.parse as ul +from typing import Callable + +import torch +from transformers import GemmaPreTrainedModel, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL +from ...models.embeddings import get_2d_rotary_pos_embed_lumina +from ...models.transformers.lumina_nextdit2d import LuminaNextDiT2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + BACKENDS_MAPPING, + deprecate, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LuminaPipeline + + >>> pipe = LuminaPipeline.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16) + >>> # Enable memory optimizations. + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. Background shows an industrial revolution cityscape with smoky skies and tall, metal structures" + >>> image = pipe(prompt).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class LuminaPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Lumina-T2I. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`GemmaPreTrainedModel`]): + Frozen Gemma text-encoder. + tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`): + Gemma tokenizer. + transformer ([`Transformer2DModel`]): + A text conditioned `Transformer2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + ] + + def __init__( + self, + transformer: LuminaNextDiT2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: GemmaPreTrainedModel, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.max_sequence_length = 256 + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + self.default_image_size = self.default_sample_size * self.vae_scale_factor + + def _get_gemma_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + clean_caption: bool | None = False, + max_length: int | None = None, + ): + device = device or self._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + pad_to_multiple_of=8, + max_length=self.max_sequence_length, + truncation=True, + padding=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device) + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because Gemma can only handle sequences up to" + f" {self.max_sequence_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = self.text_encoder( + text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + prompt_embeds = prompt_embeds.hidden_states[-2] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, prompt_attention_mask + + # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + clean_caption: bool = False, + **kwargs, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + Lumina-T2I, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Lumina-T2I, it's should be the embeddings of the "" string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt. + """ + if device is None: + device = self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + clean_caption=clean_caption, + ) + + # Get negative embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt if negative_prompt is not None else "" + + # Normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + # Padding negative prompt to the same length with prompt + prompt_max_length = prompt_embeds.shape[1] + negative_text_inputs = self.tokenizer( + negative_prompt, + padding="max_length", + max_length=prompt_max_length, + truncation=True, + return_tensors="pt", + ) + negative_text_input_ids = negative_text_inputs.input_ids.to(device) + negative_prompt_attention_mask = negative_text_inputs.attention_mask.to(device) + # Get the negative prompt embeddings + negative_prompt_embeds = self.text_encoder( + negative_text_input_ids, + attention_mask=negative_prompt_attention_mask, + output_hidden_states=True, + ) + + negative_dtype = self.text_encoder.dtype + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + _, seq_len, _ = negative_prompt_embeds.shape + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=negative_dtype, device=device) + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + negative_prompt_attention_mask = negative_prompt_attention_mask.view( + batch_size * num_images_per_prompt, -1 + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + width: int | None = None, + height: int | None = None, + num_inference_steps: int = 30, + guidance_scale: float = 4.0, + negative_prompt: str | list[str] = None, + sigmas: list[float] = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + clean_caption: bool = True, + max_sequence_length: int = 256, + scaling_watershed: float | None = 1.0, + proportional_attn: bool | None = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + ) -> ImagePipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 30): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Lumina-T2I this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + max_sequence_length (`int` defaults to 120): + Maximum sequence length to use with the `prompt`. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + + cross_attention_kwargs = {} + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if proportional_attn: + cross_attention_kwargs["base_sequence_length"] = (self.default_image_size // 16) ** 2 + + scaling_factor = math.sqrt(width * height / self.default_image_size**2) + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([prompt_attention_mask, negative_prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas + ) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + self._num_timesteps = len(timesteps) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + is_npu = latent_model_input.device.type == "npu" + if isinstance(current_timestep, float): + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + else: + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + current_timestep = torch.tensor( + [current_timestep], + dtype=dtype, + device=latent_model_input.device, + ) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image + current_timestep = 1 - current_timestep / self.scheduler.config.num_train_timesteps + + # prepare image_rotary_emb for positional encoding + # dynamic scaling_factor for different resolution. + # NOTE: For `Time-aware` denosing mechanism from Lumina-Next + # https://huggingface.co/papers/2406.18583, Sec 2.3 + # NOTE: We should compute different image_rotary_emb with different timestep. + if current_timestep[0] < scaling_watershed: + linear_factor = scaling_factor + ntk_factor = 1.0 + else: + linear_factor = 1.0 + ntk_factor = scaling_factor + image_rotary_emb = get_2d_rotary_pos_embed_lumina( + self.transformer.head_dim, + 384, + 384, + linear_factor=linear_factor, + ntk_factor=ntk_factor, + ) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=current_timestep, + encoder_hidden_states=prompt_embeds, + encoder_mask=prompt_attention_mask, + image_rotary_emb=image_rotary_emb, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred.chunk(2, dim=1)[0] + + # perform guidance scale + # NOTE: For exact reproducibility reasons, we apply classifier-free guidance on only + # three channels by default. The standard approach to cfg applies it to all channels. + # This can be done by uncommenting the following line and commenting-out the line following that. + # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] + if do_classifier_free_guidance: + noise_pred_eps, noise_pred_rest = noise_pred[:, :3], noise_pred[:, 3:] + noise_pred_cond_eps, noise_pred_uncond_eps = torch.split( + noise_pred_eps, len(noise_pred_eps) // 2, dim=0 + ) + noise_pred_half = noise_pred_uncond_eps + guidance_scale * ( + noise_pred_cond_eps - noise_pred_uncond_eps + ) + noise_pred_eps = torch.cat([noise_pred_half, noise_pred_half], dim=0) + + noise_pred = torch.cat([noise_pred_eps, noise_pred_rest], dim=1) + noise_pred, _ = noise_pred.chunk(2, dim=0) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + noise_pred = -noise_pred + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + progress_bar.update() + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + latents = latents / self.vae.config.scaling_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + else: + image = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + + +class LuminaText2ImgPipeline(LuminaPipeline): + def __init__( + self, + transformer: LuminaNextDiT2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: GemmaPreTrainedModel, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + ): + deprecation_message = "`LuminaText2ImgPipeline` has been renamed to `LuminaPipeline` and will be removed in a future version. Please use `LuminaPipeline` instead." + deprecate("diffusers.pipelines.lumina.pipeline_lumina.LuminaText2ImgPipeline", "0.34", deprecation_message) + super().__init__( + transformer=transformer, + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + ) diff --git a/diffusers/pipelines/ovis_image/__init__.py b/diffusers/pipelines/ovis_image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..275061b1f6260c9cccc38d0116e6bf582e3d7e19 --- /dev/null +++ b/diffusers/pipelines/ovis_image/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa: F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_output"] = ["OvisImagePipelineOutput"] + _import_structure["pipeline_ovis_image"] = ["OvisImagePipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_output import OvisImagePipelineOutput + from .pipeline_ovis_image import OvisImagePipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/ovis_image/pipeline_output.py b/diffusers/pipelines/ovis_image/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..4daecf3b2197565c973c1efb9d887a27677fbc9b --- /dev/null +++ b/diffusers/pipelines/ovis_image/pipeline_output.py @@ -0,0 +1,34 @@ +# Copyright 2025 Alibaba Ovis-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from diffusers.utils import BaseOutput + + +@dataclass +class OvisImagePipelineOutput(BaseOutput): + """ + Output class for Ovis-Image pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image, np.ndarray] diff --git a/diffusers/pipelines/ovis_image/pipeline_ovis_image.py b/diffusers/pipelines/ovis_image/pipeline_ovis_image.py new file mode 100644 index 0000000000000000000000000000000000000000..c8ff8227f27e7e15f2664e74413fd60bbdf13102 --- /dev/null +++ b/diffusers/pipelines/ovis_image/pipeline_ovis_image.py @@ -0,0 +1,668 @@ +# Copyright 2025 Alibaba Ovis-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import Qwen2TokenizerFast, Qwen3Model + +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, OvisImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import OvisImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import OvisImagePipeline + + >>> pipe = OvisImagePipeline.from_pretrained("AIDC-AI/Ovis-Image-7B", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = 'A creative 3D artistic render where the text "OVIS-IMAGE" is written in a bold, expressive handwritten brush style using thick, wet oil paint. The paint is a mix of vibrant rainbow colors (red, blue, yellow) swirling together like toothpaste or impasto art. You can see the ridges of the brush bristles and the glossy, wet texture of the paint. The background is a clean artist\'s canvas. Dynamic lighting creates soft shadows behind the floating paint strokes. Colorful, expressive, tactile texture, 4k detail.' + >>> image = pipe(prompt, negative_prompt="", num_inference_steps=50, guidance_scale=5.0).images[0] + >>> image.save("ovis_image.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class OvisImagePipeline( + DiffusionPipeline, +): + r""" + The Ovis-Image pipeline for text-to-image generation. + + Reference: https://github.com/AIDC-AI/Ovis-Image + + Args: + transformer ([`OvisImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen3Model`]): + Text encoder of class + [Qwen3Model](https://huggingface.co/docs/transformers/en/model_doc/qwen3#transformers.Qwen3Model). + tokenizer (`Qwen2TokenizerFast`): + Tokenizer of class + [Qwen2TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: Qwen3Model, + tokenizer: Qwen2TokenizerFast, + transformer: OvisImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Ovis-Image latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.system_prompt = "Describe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background: " + self.user_prompt_begin_id = 28 + self.tokenizer_max_length = 256 + self.user_prompt_begin_id + self.default_sample_size = 128 + + def _get_messages( + self, + prompt: str | list[str] = None, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + messages = [] + for each_prompt in prompt: + message = [ + { + "role": "user", + "content": self.system_prompt + each_prompt, + } + ] + message = self.tokenizer.apply_chat_template( + message, tokenize=False, add_generation_prompt=True, enable_thinking=False + ) + messages.append(message) + return messages + + def _get_ovis_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + messages = self._get_messages(prompt) + batch_size = len(messages) + + tokens = self.tokenizer( + messages, + padding="max_length", + truncation=True, + max_length=self.tokenizer_max_length, + return_tensors="pt", + add_special_tokens=False, + ) + input_ids = tokens.input_ids.to(device) + attention_mask = tokens.attention_mask.to(device) + outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = outputs.last_hidden_state + prompt_embeds = prompt_embeds * attention_mask[..., None] + prompt_embeds = prompt_embeds[:, self.user_prompt_begin_id :, :] + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.FloatTensor | None = None, + ): + r""" + + Args: + prompt (`str`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + device = device or self._execution_device + + if prompt_embeds is None: + prompt_embeds = self._get_ovis_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3) + text_ids[..., 1] = text_ids[..., 1] + torch.arange(prompt_embeds.shape[1])[None, :] + text_ids[..., 2] = text_ids[..., 2] + torch.arange(prompt_embeds.shape[1])[None, :] + text_ids = text_ids.to(device=device, dtype=dtype) + return prompt_embeds, text_ids + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if max_sequence_length is not None and max_sequence_length > 256: + raise ValueError(f"`max_sequence_length` cannot be greater than 256 but is {max_sequence_length}") + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = "", + guidance_scale: float = 5.0, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 256, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + not greater than `1`). + guidance_scale (`float`, *optional*, defaults to 1.0): + True classifier-free guidance (guidance scale) is enabled when `guidance_scale` > 1 and + `negative_prompt` is provided. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ovis_image.OvisImagePipelineOutput`] or `tuple`: + [`~pipelines.ovis_image.OvisImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + do_classifier_free_guidance = guidance_scale > 1 + ( + prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + if do_classifier_free_guidance: + ( + negative_prompt_embeds, + negative_text_ids, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + # 6. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return OvisImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/pag/__init__.py b/diffusers/pipelines/pag/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..176efe3adef6a792ba2e7fef194e2a0a9475fbfd --- /dev/null +++ b/diffusers/pipelines/pag/__init__.py @@ -0,0 +1,80 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"] + _import_structure["pipeline_pag_controlnet_sd_inpaint"] = ["StableDiffusionControlNetPAGInpaintPipeline"] + _import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"] + _import_structure["pipeline_pag_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetPAGImg2ImgPipeline"] + _import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"] + _import_structure["pipeline_pag_kolors"] = ["KolorsPAGPipeline"] + _import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"] + _import_structure["pipeline_pag_sana"] = ["SanaPAGPipeline"] + _import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"] + _import_structure["pipeline_pag_sd_3"] = ["StableDiffusion3PAGPipeline"] + _import_structure["pipeline_pag_sd_3_img2img"] = ["StableDiffusion3PAGImg2ImgPipeline"] + _import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"] + _import_structure["pipeline_pag_sd_img2img"] = ["StableDiffusionPAGImg2ImgPipeline"] + _import_structure["pipeline_pag_sd_inpaint"] = ["StableDiffusionPAGInpaintPipeline"] + + _import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"] + _import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"] + _import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline + from .pipeline_pag_controlnet_sd_inpaint import StableDiffusionControlNetPAGInpaintPipeline + from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline + from .pipeline_pag_controlnet_sd_xl_img2img import StableDiffusionXLControlNetPAGImg2ImgPipeline + from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline + from .pipeline_pag_kolors import KolorsPAGPipeline + from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline + from .pipeline_pag_sana import SanaPAGPipeline + from .pipeline_pag_sd import StableDiffusionPAGPipeline + from .pipeline_pag_sd_3 import StableDiffusion3PAGPipeline + from .pipeline_pag_sd_3_img2img import StableDiffusion3PAGImg2ImgPipeline + from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline + from .pipeline_pag_sd_img2img import StableDiffusionPAGImg2ImgPipeline + from .pipeline_pag_sd_inpaint import StableDiffusionPAGInpaintPipeline + from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline + from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline + from .pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/pag/pag_utils.py b/diffusers/pipelines/pag/pag_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..41395ece742155a9f0a9003cf53e398ca8661c23 --- /dev/null +++ b/diffusers/pipelines/pag/pag_utils.py @@ -0,0 +1,242 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +import torch +import torch.nn as nn + +from ...models.attention_processor import ( + Attention, + AttentionProcessor, + PAGCFGIdentitySelfAttnProcessor2_0, + PAGIdentitySelfAttnProcessor2_0, +) +from ...utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class PAGMixin: + r"""Mixin class for [Pertubed Attention Guidance](https://huggingface.co/papers/2403.17377v1).""" + + def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance): + r""" + Set the attention processor for the PAG layers. + """ + pag_attn_processors = self._pag_attn_processors + if pag_attn_processors is None: + raise ValueError( + "No PAG attention processors have been set. Set the attention processors by calling `set_pag_applied_layers` and passing the relevant parameters." + ) + + pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1] + + if hasattr(self, "unet"): + model: nn.Module = self.unet + else: + model: nn.Module = self.transformer + + def is_self_attn(module: nn.Module) -> bool: + r""" + Check if the module is self-attention module based on its name. + """ + return isinstance(module, Attention) and not module.is_cross_attention + + def is_fake_integral_match(layer_id, name): + layer_id = layer_id.split(".")[-1] + name = name.split(".")[-1] + return layer_id.isnumeric() and name.isnumeric() and layer_id == name + + for layer_id in pag_applied_layers: + # for each PAG layer input, we find corresponding self-attention layers in the unet model + target_modules = [] + + for name, module in model.named_modules(): + # Identify the following simple cases: + # (1) Self Attention layer existing + # (2) Whether the module name matches pag layer id even partially + # (3) Make sure it's not a fake integral match if the layer_id ends with a number + # For example, blocks.1, blocks.10 should be differentiable if layer_id="blocks.1" + if ( + is_self_attn(module) + and re.search(layer_id, name) is not None + and not is_fake_integral_match(layer_id, name) + ): + logger.debug(f"Applying PAG to layer: {name}") + target_modules.append(module) + + if len(target_modules) == 0: + raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}") + + for module in target_modules: + module.processor = pag_attn_proc + + def _get_pag_scale(self, t): + r""" + Get the scale factor for the perturbed attention guidance at timestep `t`. + """ + + if self.do_pag_adaptive_scaling: + signal_scale = self.pag_scale - self.pag_adaptive_scale * (1000 - t) + if signal_scale < 0: + signal_scale = 0 + return signal_scale + else: + return self.pag_scale + + def _apply_perturbed_attention_guidance( + self, noise_pred, do_classifier_free_guidance, guidance_scale, t, return_pred_text=False + ): + r""" + Apply perturbed attention guidance to the noise prediction. + + Args: + noise_pred (torch.Tensor): The noise prediction tensor. + do_classifier_free_guidance (bool): Whether to apply classifier-free guidance. + guidance_scale (float): The scale factor for the guidance term. + t (int): The current time step. + return_pred_text (bool): Whether to return the text noise prediction. + + Returns: + torch.Tensor | tuple[torch.Tensor, torch.Tensor]: The updated noise prediction tensor after applying + perturbed attention guidance and the text noise prediction. + """ + pag_scale = self._get_pag_scale(t) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3) + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_uncond) + + pag_scale * (noise_pred_text - noise_pred_perturb) + ) + else: + noise_pred_text, noise_pred_perturb = noise_pred.chunk(2) + noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb) + if return_pred_text: + return noise_pred, noise_pred_text + return noise_pred + + def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance): + """ + Prepares the perturbed attention guidance for the PAG model. + + Args: + cond (torch.Tensor): The conditional input tensor. + uncond (torch.Tensor): The unconditional input tensor. + do_classifier_free_guidance (bool): Flag indicating whether to perform classifier-free guidance. + + Returns: + torch.Tensor: The prepared perturbed attention guidance tensor. + """ + + cond = torch.cat([cond] * 2, dim=0) + + if do_classifier_free_guidance: + cond = torch.cat([uncond, cond], dim=0) + return cond + + def set_pag_applied_layers( + self, + pag_applied_layers: str | list[str], + pag_attn_processors: tuple[AttentionProcessor, AttentionProcessor] = ( + PAGCFGIdentitySelfAttnProcessor2_0(), + PAGIdentitySelfAttnProcessor2_0(), + ), + ): + r""" + Set the self-attention layers to apply PAG. Raise ValueError if the input is invalid. + + Args: + pag_applied_layers (`str` or `list[str]`): + One or more strings identifying the layer names, or a simple regex for matching multiple layers, where + PAG is to be applied. A few ways of expected usage are as follows: + - Single layers specified as - "blocks.{layer_index}" + - Multiple layers as a list - ["blocks.{layers_index_1}", "blocks.{layer_index_2}", ...] + - Multiple layers as a block name - "mid" + - Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})" + pag_attn_processors: + (`tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(), + PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention + processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second + attention processor is for PAG with CFG disabled (unconditional only). + """ + + if not hasattr(self, "_pag_attn_processors"): + self._pag_attn_processors = None + + if not isinstance(pag_applied_layers, list): + pag_applied_layers = [pag_applied_layers] + if pag_attn_processors is not None: + if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2: + raise ValueError("Expected a tuple of two attention processors") + + for i in range(len(pag_applied_layers)): + if not isinstance(pag_applied_layers[i], str): + raise ValueError( + f"Expected either a string or a list of string but got type {type(pag_applied_layers[i])}" + ) + + self.pag_applied_layers = pag_applied_layers + self._pag_attn_processors = pag_attn_processors + + @property + def pag_scale(self) -> float: + r"""Get the scale factor for the perturbed attention guidance.""" + return self._pag_scale + + @property + def pag_adaptive_scale(self) -> float: + r"""Get the adaptive scale factor for the perturbed attention guidance.""" + return self._pag_adaptive_scale + + @property + def do_pag_adaptive_scaling(self) -> bool: + r"""Check if the adaptive scaling is enabled for the perturbed attention guidance.""" + return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0 + + @property + def do_perturbed_attention_guidance(self) -> bool: + r"""Check if the perturbed attention guidance is enabled.""" + return self._pag_scale > 0 and len(self.pag_applied_layers) > 0 + + @property + def pag_attn_processors(self) -> dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of PAG attention processors: A dictionary contains all PAG attention processors used in the model + with the key as the name of the layer. + """ + + if self._pag_attn_processors is None: + return {} + + valid_attn_processors = {x.__class__ for x in self._pag_attn_processors} + + processors = {} + # We could have iterated through the self.components.items() and checked if a component is + # `ModelMixin` subclassed but that can include a VAE too. + if hasattr(self, "unet"): + denoiser_module = self.unet + elif hasattr(self, "transformer"): + denoiser_module = self.transformer + else: + raise ValueError("No denoiser module found.") + + for name, proc in denoiser_module.attn_processors.items(): + if proc.__class__ in valid_attn_processors: + processors[name] = proc + + return processors diff --git a/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py b/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py new file mode 100644 index 0000000000000000000000000000000000000000..807c42d21bb4f3bc858bcf25bb4a9d1528e888b3 --- /dev/null +++ b/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py @@ -0,0 +1,1348 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import AutoPipelineForText2Image, ControlNetModel, UniPCMultistepScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" + ... ) + >>> image = np.array(image) + + >>> # get canny image + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # load control net and stable diffusion v1-5 + >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "stable-diffusion-v1-5/stable-diffusion-v1-5", + ... controlnet=controlnet, + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... ) + + >>> # speed up diffusion process with faster scheduler and memory optimization + >>> # remove following line if xformers is not installed + >>> pipe.enable_xformers_memory_efficient_attention() + + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> generator = torch.manual_seed(0) + >>> image = pipe( + ... "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting", + ... guidance_scale=7.5, + ... generator=generator, + ... image=canny_image, + ... pag_scale=10, + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionControlNetPAGPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, + PAGMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `list[ControlNetModel]`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: ControlNetModel | list[ControlNetModel] | tuple[ControlNetModel] | MultiControlNetModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + pag_applied_layers: str | list[str] = "mid", + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + transposed_image = [list(t) for t in zip(*image)] + if len(transposed_image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: if you pass`image` as a list of list, each sublist must have the same length as the number of controlnets, but the sublists in `image` got {len(transposed_image)} images and {len(self.controlnet.nets)} ControlNets." + ) + for image_ in transposed_image: + self.check_image(image_, prompt, prompt_embeds) + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + else: + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError( + "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. " + "The conditioning scale must be fixed across the batch." + ) + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + guess_mode: bool = False, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single + ControlNet, each will be paired with each prompt in the `prompt` list. This also applies to multiple + ControlNets, where a list of image lists can be passed to batch for each prompt and each ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + # Nested lists as ControlNet condition + if isinstance(image[0], list): + # Transpose the nested image list + image = [list(t) for t in zip(*image)] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Add image embeds for IP-Adapter + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + added_cond_kwargs = ( + {"image_embeds": ip_adapter_image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + controlnet_prompt_embeds = prompt_embeds + + # 7.2 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + images = image if isinstance(image, list) else [image] + for i, single_image in enumerate(images): + if self.do_classifier_free_guidance: + single_image = single_image.chunk(2)[0] + + if self.do_perturbed_attention_guidance: + single_image = self._prepare_perturbed_attention_guidance( + single_image, single_image, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + single_image = torch.cat([single_image] * 2) + single_image = single_image.to(device) + images[i] = single_image + + image = images if isinstance(image, list) else images[0] + + # 8. Denoising loop + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if ( + torch.cuda.is_available() + and (is_unet_compiled and is_controlnet_compiled) + and is_torch_higher_equal_2_1 + ): + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + control_model_input = latent_model_input + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Inferred ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + empty_device_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py b/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..ebc2e882868cc2d2ddc9f83f3711af950921a298 --- /dev/null +++ b/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py @@ -0,0 +1,1552 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/ + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install transformers accelerate + >>> import cv2 + >>> from diffusers import AutoPipelineForInpainting, ControlNetModel, DDIMScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> from PIL import Image + >>> import torch + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png" + ... ) + >>> init_image = init_image.resize((512, 512)) + + >>> generator = torch.Generator(device="cpu").manual_seed(1) + + >>> mask_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png" + ... ) + >>> mask_image = mask_image.resize((512, 512)) + + + >>> def make_canny_condition(image): + ... image = np.array(image) + ... image = cv2.Canny(image, 100, 200) + ... image = image[:, :, None] + ... image = np.concatenate([image, image, image], axis=2) + ... image = Image.fromarray(image) + ... return image + + + >>> control_image = make_canny_condition(init_image) + + >>> controlnet = ControlNetModel.from_pretrained( + ... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16 + ... ) + >>> pipe = AutoPipelineForInpainting.from_pretrained( + ... "stable-diffusion-v1-5/stable-diffusion-v1-5", + ... controlnet=controlnet, + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... ) + + >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> image = pipe( + ... "a handsome man with ray-ban sunglasses", + ... num_inference_steps=20, + ... generator=generator, + ... eta=1.0, + ... image=init_image, + ... mask_image=mask_image, + ... control_image=control_image, + ... pag_scale=0.3, + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class StableDiffusionControlNetPAGInpaintPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, + PAGMixin, +): + r""" + Pipeline for image inpainting using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + > [!TIP] > This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting > + ([stable-diffusion-v1-5/stable-diffusion-inpainting](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting)) + as well as > default text-to-image Stable Diffusion checkpoints > + ([stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)). + Default text-to-image > Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned + on those, such as > + [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `list[ControlNetModel]`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: ControlNetModel | list[ControlNetModel] | tuple[ControlNetModel] | MultiControlNetModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + pag_applied_layers: str | list[str] = "mid", + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + image, + mask_image, + height, + width, + output_type, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + ): + if height is not None and height % 8 != 0 or width is not None and width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_inpaint.StableDiffusionControlNetInpaintPipeline.prepare_control_image + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords, + resize_mode, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 4: + image_latents = image + else: + image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + control_image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + padding_mask_crop: int | None = None, + strength: float = 1.0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + controlnet_conditioning_scale: float | list[float] = 0.5, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, + `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, NumPy array or tensor representing an image batch to be used as the starting point. For both + NumPy array and PyTorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a + list or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a NumPy array or + a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, + `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, NumPy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a NumPy array or PyTorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for PyTorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for NumPy array, it would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, + W, 1)`, or `(H, W)`. + control_image (`torch.Tensor`, `PIL.Image.Image`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, + `list[list[torch.Tensor]]`, or `list[list[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 0.5): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + control_image, + mask_image, + height, + width, + output_type, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + padding_mask_crop, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if padding_mask_crop is not None: + height, width = self.image_processor.get_default_height_width(image, height, width) + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare control image + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + crops_coords=crops_coords, + resize_mode=resize_mode, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + crops_coords=crops_coords, + resize_mode=resize_mode, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + + control_images.append(control_image_) + + control_image = control_images + else: + assert False + + # 4.1 Preprocess mask and image - resizes image and mask w.r.t height and width + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + mask = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + masked_image = init_image * (mask < 0.5) + _, _, height, width = init_image.shape + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance, + ) + + # 7.1 Check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for stable-diffusion-v1-5/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + + # 7.2 Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.3 Prepare embeddings + # ip-adapter + if ip_adapter_image_embeds is not None: + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + added_cond_kwargs = ( + {"image_embeds": ip_adapter_image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # control image + control_images = control_image if isinstance(control_image, list) else [control_image] + for i, single_control_image in enumerate(control_images): + if self.do_classifier_free_guidance: + single_control_image = single_control_image.chunk(2)[0] + + if self.do_perturbed_attention_guidance: + single_control_image = self._prepare_perturbed_attention_guidance( + single_control_image, single_control_image, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + single_control_image = torch.cat([single_control_image] * 2) + single_control_image = single_control_image.to(device) + control_images[i] = single_control_image + + control_image = control_images if isinstance(control_image, list) else control_images[0] + controlnet_prompt_embeds = prompt_embeds + + # 7.4 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 7.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + control_model_input = latent_model_input + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + guess_mode=False, + return_dict=False, + ) + + # concat latents, mask, masked_image_latents in the channel dimension + if num_channels_unet == 9: + first_dim_size = latent_model_input.shape[0] + # Ensure mask and masked_image_latents have the right dimensions + if mask.shape[0] < first_dim_size: + repeat_factor = (first_dim_size + mask.shape[0] - 1) // mask.shape[0] + mask = mask.repeat(repeat_factor, 1, 1, 1)[:first_dim_size] + if masked_image_latents.shape[0] < first_dim_size: + repeat_factor = ( + first_dim_size + masked_image_latents.shape[0] - 1 + ) // masked_image_latents.shape[0] + masked_image_latents = masked_image_latents.repeat(repeat_factor, 1, 1, 1)[:first_dim_size] + # Perform the concatenation + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # Predict noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if num_channels_unet == 4: + init_latents_proper = image_latents + if self.do_classifier_free_guidance: + init_mask, _ = mask.chunk(2) + else: + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + empty_device_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if padding_mask_crop is not None: + image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py b/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..b0e2c03faed76357f1d691bcd500c065f81a927d --- /dev/null +++ b/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py @@ -0,0 +1,1621 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from diffusers.utils.import_utils import is_invisible_watermark_available + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from .pag_utils import PAGMixin + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + + +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import AutoPipelineForText2Image, ControlNetModel, AutoencoderKL + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" + >>> negative_prompt = "low quality, bad quality, sketches" + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" + ... ) + + >>> # initialize the models and pipeline + >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization + >>> controlnet = ControlNetModel.from_pretrained( + ... "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16 + ... ) + >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", + ... controlnet=controlnet, + ... vae=vae, + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> # get canny image + >>> image = np.array(image) + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # generate image + >>> image = pipe( + ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image, pag_scale=0.3 + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLControlNetPAGPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, + PAGMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]): + Second frozen text-encoder + ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + tokenizer_2 ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `list[ControlNetModel]`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings should always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to + watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no + watermarker is used. + """ + + # leave controlnet out on purpose because it iterates with unet + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + "image_encoder", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: ControlNetModel | list[ControlNetModel] | tuple[ControlNetModel] | MultiControlNetModel, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: bool | None = None, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + pag_applied_layers: str | list[str] = "mid", # ["down.block_2", "up.block_1.attentions_0"], "mid" + ): + super().__init__() + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + image, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + negative_pooled_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + denoising_end: float | None = None, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + original_size: tuple[int, int] = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` + and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, pooled text embeddings are generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt + weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input + argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned containing the output images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + None, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + # 3.1 Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 3.2 Encode ip_adapter_image + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.1 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 7.2 Prepare added time ids & embeddings + if isinstance(image, list): + original_size = original_size or image[0].shape[-2:] + else: + original_size = original_size or image.shape[-2:] + target_size = target_size or (height, width) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + images = image if isinstance(image, list) else [image] + for i, single_image in enumerate(images): + if self.do_classifier_free_guidance: + single_image = single_image.chunk(2)[0] + + if self.do_perturbed_attention_guidance: + single_image = self._prepare_perturbed_attention_guidance( + single_image, single_image, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + single_image = torch.cat([single_image] * 2) + single_image = single_image.to(device) + images[i] = single_image + + image = images if isinstance(image, list) else images[0] + + if ip_adapter_image_embeds is not None: + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + add_text_embeds = self._prepare_perturbed_attention_guidance( + add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + add_time_ids = self._prepare_perturbed_attention_guidance( + add_time_ids, negative_add_time_ids, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if ( + torch.cuda.is_available() + and (is_unet_compiled and is_controlnet_compiled) + and is_torch_higher_equal_2_1 + ): + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + control_model_input = latent_model_input + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=False, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + if ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py b/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..8967e50251b9b89e544c45130c2390aa86a7e649 --- /dev/null +++ b/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py @@ -0,0 +1,1685 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from diffusers.utils.import_utils import is_invisible_watermark_available + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from .pag_utils import PAGMixin + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + + +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # pip install accelerate transformers safetensors diffusers + + >>> import torch + >>> import numpy as np + >>> from PIL import Image + + >>> from transformers import DPTFeatureExtractor, DPTForDepthEstimation + >>> from diffusers import ControlNetModel, StableDiffusionXLControlNetPAGImg2ImgPipeline, AutoencoderKL + >>> from diffusers.utils import load_image + + >>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda") + >>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas") + >>> controlnet = ControlNetModel.from_pretrained( + ... "diffusers/controlnet-depth-sdxl-1.0-small", + ... variant="fp16", + ... use_safetensors="True", + ... torch_dtype=torch.float16, + ... ) + >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) + >>> pipe = StableDiffusionXLControlNetPAGImg2ImgPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", + ... controlnet=controlnet, + ... vae=vae, + ... variant="fp16", + ... use_safetensors=True, + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... ) + >>> pipe.enable_model_cpu_offload() + + + >>> def get_depth_map(image): + ... image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda") + ... with torch.no_grad(), torch.autocast("cuda"): + ... depth_map = depth_estimator(image).predicted_depth + + ... depth_map = torch.nn.functional.interpolate( + ... depth_map.unsqueeze(1), + ... size=(1024, 1024), + ... mode="bicubic", + ... align_corners=False, + ... ) + ... depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True) + ... depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True) + ... depth_map = (depth_map - depth_min) / (depth_max - depth_min) + ... image = torch.cat([depth_map] * 3, dim=1) + ... image = image.permute(0, 2, 3, 1).cpu().numpy()[0] + ... image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8)) + ... return image + + + >>> prompt = "A robot, 4k photo" + >>> image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ).resize((1024, 1024)) + >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization + >>> depth_image = get_depth_map(image) + + >>> images = pipe( + ... prompt, + ... image=image, + ... control_image=depth_image, + ... strength=0.99, + ... num_inference_steps=50, + ... controlnet_conditioning_scale=controlnet_conditioning_scale, + ... ).images + >>> images[0].save(f"robot_cat.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class StableDiffusionXLControlNetPAGImg2ImgPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + FromSingleFileMixin, + IPAdapterMixin, + PAGMixin, +): + r""" + Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `list[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets + as a list, the outputs from each ControlNet are added together to create one combined additional + conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): + Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the + config of `stabilityai/stable-diffusion-xl-refiner-1-0`. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + "image_encoder", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "add_neg_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: ControlNetModel | list[ControlNetModel] | tuple[ControlNetModel] | MultiControlNetModel, + scheduler: KarrasDiffusionSchedulers, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: bool | None = None, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + pag_applied_layers: str | list[str] = "mid", # ["mid"], ["down.block_1", "up.block_0.attentions_0"] + ): + super().__init__() + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl_img2img.StableDiffusionXLControlNetImg2ImgPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + image, + strength, + num_inference_steps, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if num_inference_steps is None: + raise ValueError("`num_inference_steps` cannot be None.") + elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0: + raise ValueError( + f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" + f" {type(num_inference_steps)}." + ) + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents + def prepare_latents( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + + # Offload text encoder if `enable_model_cpu_offload` was enabled + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.text_encoder_2.to("cpu") + empty_device_cache() + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + image: PipelineImageInput = None, + control_image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + strength: float = 0.8, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + controlnet_conditioning_scale: float | list[float] = 0.8, + guess_mode: bool = False, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + original_size: tuple[int, int] = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The initial image will be used as the starting point for the image generation process. Can also accept + image latents as `image`, if passing latents directly, it will not be encoded again. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also + be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height + and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in + init, images must be passed as a list such that each element of the list can be correctly batched for + input to a single controlnet. + height (`int`, *optional*, defaults to the size of control_image): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to the size of control_image): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the controlnet starts applying. + control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the controlnet stops applying. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple` containing the output images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + control_image, + strength, + num_inference_steps, + None, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + # 3.1 Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 3.2 Encode ip_adapter_image + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare image and controlnet_conditioning_image + image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + height, width = control_image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + + control_images.append(control_image_) + + control_image = control_images + height, width = control_image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + True, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 7.2 Prepare added time ids & embeddings + if isinstance(control_image, list): + original_size = original_size or control_image[0].shape[-2:] + else: + original_size = original_size or control_image.shape[-2:] + target_size = target_size or (height, width) + + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size + add_text_embeds = pooled_prompt_embeds + + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + + control_images = control_image if isinstance(control_image, list) else [control_image] + for i, single_image in enumerate(control_images): + if self.do_classifier_free_guidance: + single_image = single_image.chunk(2)[0] + + if self.do_perturbed_attention_guidance: + single_image = self._prepare_perturbed_attention_guidance( + single_image, single_image, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + single_image = torch.cat([single_image] * 2) + single_image = single_image.to(device) + control_images[i] = single_image + + control_image = control_images if isinstance(control_image, list) else control_images[0] + + if ip_adapter_image_embeds is not None: + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + add_text_embeds = self._prepare_perturbed_attention_guidance( + add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + add_time_ids = self._prepare_perturbed_attention_guidance( + add_time_ids, add_neg_time_ids, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + control_model_input = latent_model_input + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + guess_mode=False, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + if ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + empty_device_cache() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py new file mode 100644 index 0000000000000000000000000000000000000000..15ac665acd2b0f3ff75886f21695b9af087e513f --- /dev/null +++ b/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py @@ -0,0 +1,964 @@ +# Copyright 2025 HunyuanDiT Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import torch +from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, HunyuanDiT2DModel +from ...models.attention_processor import PAGCFGHunyuanAttnProcessor2_0, PAGHunyuanAttnProcessor2_0 +from ...models.embeddings import get_2d_rotary_pos_embed +from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from ...schedulers import DDPMScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import AutoPipelineForText2Image + + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... pag_applied_layers=[14], + ... ).to("cuda") + + >>> # prompt = "an astronaut riding a horse" + >>> prompt = "一个宇航员在骑马" + >>> image = pipe(prompt, guidance_scale=4, pag_scale=3).images[0] + ``` +""" + +STANDARD_RATIO = np.array( + [ + 1.0, # 1:1 + 4.0 / 3.0, # 4:3 + 3.0 / 4.0, # 3:4 + 16.0 / 9.0, # 16:9 + 9.0 / 16.0, # 9:16 + ] +) +STANDARD_SHAPE = [ + [(1024, 1024), (1280, 1280)], # 1:1 + [(1024, 768), (1152, 864), (1280, 960)], # 4:3 + [(768, 1024), (864, 1152), (960, 1280)], # 3:4 + [(1280, 768)], # 16:9 + [(768, 1280)], # 9:16 +] +STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE] +SUPPORTED_SHAPE = [ + (1024, 1024), + (1280, 1280), # 1:1 + (1024, 768), + (1152, 864), + (1280, 960), # 4:3 + (768, 1024), + (864, 1152), + (960, 1280), # 3:4 + (1280, 768), # 16:9 + (768, 1280), # 9:16 +] + + +def map_to_standard_shapes(target_width, target_height): + target_ratio = target_width / target_height + closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio)) + closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height)) + width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx] + return width, height + + +def get_resize_crop_region_for_grid(src, tgt_size): + th = tw = tgt_size + h, w = src + + r = h / w + + # resize + if r > 1: + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin): + r""" + Pipeline for English/Chinese-to-image generation using HunyuanDiT and [Perturbed Attention + Guidance](https://huggingface.co/docs/diffusers/en/using-diffusers/pag). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + HunyuanDiT uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by + ourselves) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. We use + `sdxl-vae-fp16-fix`. + text_encoder (`~transformers.BertModel`, `~transformers.CLIPTextModel` | None): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + HunyuanDiT uses a fine-tuned [bilingual CLIP]. + tokenizer (`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer` | None): + A `BertTokenizer` or `CLIPTokenizer` to tokenize text. + transformer ([`HunyuanDiT2DModel`]): + The HunyuanDiT model designed by Tencent Hunyuan. + text_encoder_2 (`T5EncoderModel`): + The mT5 embedder. Specifically, it is 't5-v1_1-xxl'. + tokenizer_2 (`T5Tokenizer`): + The tokenizer for the mT5 embedder. + scheduler ([`DDPMScheduler`]): + A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [ + "safety_checker", + "feature_extractor", + "text_encoder_2", + "tokenizer_2", + "text_encoder", + "tokenizer", + ] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "prompt_embeds_2", + "negative_prompt_embeds_2", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: BertModel, + tokenizer: BertTokenizer, + transformer: HunyuanDiT2DModel, + scheduler: DDPMScheduler, + safety_checker: StableDiffusionSafetyChecker | None = None, + feature_extractor: CLIPImageProcessor | None = None, + requires_safety_checker: bool = True, + text_encoder_2: T5EncoderModel | None = None, + tokenizer_2: T5Tokenizer | None = None, + pag_applied_layers: str | list[str] = "blocks.1", # "blocks.16.attn1", "blocks.16", "16", 16 + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + text_encoder_2=text_encoder_2, + ) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + + self.set_pag_applied_layers( + pag_applied_layers, pag_attn_processors=(PAGCFGHunyuanAttnProcessor2_0(), PAGHunyuanAttnProcessor2_0()) + ) + + # Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + device: torch.device = None, + dtype: torch.dtype = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + max_sequence_length: int | None = None, + text_encoder_index: int = 0, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + dtype (`torch.dtype`): + torch dtype + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. + text_encoder_index (`int`, *optional*): + Index of the text encoder to use. `0` for clip and `1` for T5. + """ + if dtype is None: + if self.text_encoder_2 is not None: + dtype = self.text_encoder_2.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + if device is None: + device = self._execution_device + + tokenizers = [self.tokenizer, self.tokenizer_2] + text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = tokenizers[text_encoder_index] + text_encoder = text_encoders[text_encoder_index] + + if max_sequence_length is None: + if text_encoder_index == 0: + max_length = 77 + if text_encoder_index == 1: + max_length = 256 + else: + max_length = max_sequence_length + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = text_encoder( + text_input_ids.to(device), + attention_mask=prompt_attention_mask, + ) + prompt_embeds = prompt_embeds[0] + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_attention_mask = uncond_input.attention_mask.to(device) + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + attention_mask=negative_prompt_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + prompt_embeds_2=None, + negative_prompt_embeds_2=None, + prompt_attention_mask_2=None, + negative_prompt_attention_mask_2=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if prompt_embeds_2 is not None and prompt_attention_mask_2 is None: + raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None: + raise ValueError( + "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`." + ) + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None: + if prompt_embeds_2.shape != negative_prompt_embeds_2.shape: + raise ValueError( + "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but" + f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`" + f" {negative_prompt_embeds_2.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_2: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_2: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + prompt_attention_mask_2: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask_2: torch.Tensor | None = None, + output_type: str = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int, dict], None] + | PipelineCallback + | MultiPipelineCallbacks + | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + guidance_rescale: float = 0.0, + original_size: tuple[int, int] = (1024, 1024), + target_size: tuple[int, int] = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + use_resolution_binning: bool = True, + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation with HunyuanDiT. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds_2` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds_2` is passed directly. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback_on_step_end (`Callable[[int, int], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A callback function or a list of callback functions to be called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`list[str]`, *optional*): + A list of tensor inputs that should be passed to the callback function. If not defined, all tensor + inputs will be passed. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise + Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). See Section 3.4 + original_size (`tuple[int, int]`, *optional*, defaults to `(1024, 1024)`): + The original size of the image. Used to calculate the time ids. + target_size (`tuple[int, int]`, *optional*): + The target size of the image. Used to calculate the time ids. + crops_coords_top_left (`tuple[int, int]`, *optional*, defaults to `(0, 0)`): + The top left coordinates of the crop. Used to calculate the time ids. + use_resolution_binning (`bool`, *optional*, defaults to `True`): + Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest + standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960, + 768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + height = int((height // 16) * 16) + width = int((width // 16) * 16) + + if use_resolution_binning and (height, width) not in SUPPORTED_SHAPE: + width, height = map_to_standard_shapes(width, height) + height = int(height) + width = int(width) + logger.warning(f"Reshaped to (height, width)=({height}, {width}), Supported shapes are {SUPPORTED_SHAPE}") + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=77, + text_encoder_index=0, + ) + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds_2, + negative_prompt_embeds=negative_prompt_embeds_2, + prompt_attention_mask=prompt_attention_mask_2, + negative_prompt_attention_mask=negative_prompt_attention_mask_2, + max_sequence_length=256, + text_encoder_index=1, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create image_rotary_emb, style embedding & time ids + grid_height = height // 8 // self.transformer.config.patch_size + grid_width = width // 8 // self.transformer.config.patch_size + base_size = 512 // 8 // self.transformer.config.patch_size + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) + image_rotary_emb = get_2d_rotary_pos_embed( + self.transformer.inner_dim // self.transformer.num_heads, + grid_crops_coords, + (grid_height, grid_width), + device=device, + output_type="pt", + ) + + style = torch.tensor([0], device=device) + + target_size = target_size or (height, width) + add_time_ids = list(original_size + target_size + crops_coords_top_left) + add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + prompt_attention_mask = self._prepare_perturbed_attention_guidance( + prompt_attention_mask, negative_prompt_attention_mask, self.do_classifier_free_guidance + ) + prompt_embeds_2 = self._prepare_perturbed_attention_guidance( + prompt_embeds_2, negative_prompt_embeds_2, self.do_classifier_free_guidance + ) + prompt_attention_mask_2 = self._prepare_perturbed_attention_guidance( + prompt_attention_mask_2, negative_prompt_attention_mask_2, self.do_classifier_free_guidance + ) + add_time_ids = torch.cat([add_time_ids] * 3, dim=0) + style = torch.cat([style] * 3, dim=0) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) + prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) + add_time_ids = torch.cat([add_time_ids] * 2, dim=0) + style = torch.cat([style] * 2, dim=0) + + prompt_embeds = prompt_embeds.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + prompt_embeds_2 = prompt_embeds_2.to(device=device) + prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device) + add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat( + batch_size * num_images_per_prompt, 1 + ) + style = style.to(device=device).repeat(batch_size * num_images_per_prompt) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.transformer.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( + dtype=latent_model_input.dtype + ) + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + text_embedding_mask=prompt_attention_mask, + encoder_hidden_states_t5=prompt_embeds_2, + text_embedding_mask_t5=prompt_attention_mask_2, + image_meta_size=add_time_ids, + style=style, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + noise_pred, _ = noise_pred.chunk(2, dim=1) + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2) + negative_prompt_embeds_2 = callback_outputs.pop( + "negative_prompt_embeds_2", negative_prompt_embeds_2 + ) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # 9. Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.transformer.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/pag/pipeline_pag_kolors.py b/diffusers/pipelines/pag/pipeline_pag_kolors.py new file mode 100644 index 0000000000000000000000000000000000000000..4f138d91d9c6d1529e9627925c9def05a3184f38 --- /dev/null +++ b/diffusers/pipelines/pag/pipeline_pag_kolors.py @@ -0,0 +1,1129 @@ +# Copyright 2025 Stability AI, Kwai-Kolors Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import Any, Callable + +import torch +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..kolors.pipeline_output import KolorsPipelineOutput +from ..kolors.text_encoder import ChatGLMModel +from ..kolors.tokenizer import ChatGLMTokenizer +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForText2Image + + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "Kwai-Kolors/Kolors-diffusers", + ... variant="fp16", + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... pag_applied_layers=["down.block_2.attentions_1", "up.block_0.attentions_1"], + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = ( + ... "A photo of a ladybug, macro, zoom, high quality, film, holding a wooden sign with the text 'KOLORS'" + ... ) + >>> image = pipe(prompt, guidance_scale=5.5, pag_scale=1.5).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class KolorsPAGPipeline( + DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin, PAGMixin +): + r""" + Pipeline for text-to-image generation using Kolors. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`ChatGLMModel`]): + Frozen text-encoder. Kolors uses [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b). + tokenizer (`ChatGLMTokenizer`): + Tokenizer of class + [ChatGLMTokenizer](https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"False"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `Kwai-Kolors/Kolors-diffusers`. + pag_applied_layers (`str` or `list[str]``, *optional*, defaults to `"mid"`): + Set the transformer attention layers where to apply the perturbed attention guidance. Can be a string or a + list of strings with "down", "mid", "up", a whole transformer block or specific transformer block attention + layers, e.g.: + ["mid"] ["down", "mid"] ["down", "mid", "up.block_1"] ["down", "mid", "up.block_1.attentions_0", + "up.block_1.attentions_1"] + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = [ + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: ChatGLMModel, + tokenizer: ChatGLMTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = False, + pag_applied_layers: str | list[str] = "mid", + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.kolors.pipeline_kolors.KolorsPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt=None, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 256, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + """ + # from IPython import embed; embed(); exit() + device = device or self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer] + text_encoders = [self.text_encoder] + + if prompt_embeds is None: + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ).to(device) + output = text_encoder( + input_ids=text_inputs["input_ids"], + attention_mask=text_inputs["attention_mask"], + position_ids=text_inputs["position_ids"], + output_hidden_states=True, + ) + + # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] + # clone to have a contiguous tensor + prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone() + # [max_sequence_length, batch, hidden_size] -> [batch, hidden_size] + pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = prompt_embeds_list[0] + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ).to(device) + output = text_encoder( + input_ids=uncond_input["input_ids"], + attention_mask=uncond_input["attention_mask"], + position_ids=uncond_input["position_ids"], + output_hidden_states=True, + ) + + # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] + # clone to have a contiguous tensor + negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone() + # [max_sequence_length, batch, hidden_size] -> [batch, hidden_size] + negative_pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = negative_prompt_embeds_list[0] + + bs_embed = pooled_prompt_embeds.shape[0] + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.kolors.pipeline_kolors.KolorsPipeline.check_inputs + def check_inputs( + self, + prompt, + num_inference_steps, + height, + width, + negative_prompt=None, + prompt_embeds=None, + pooled_prompt_embeds=None, + negative_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if not isinstance(num_inference_steps, int) or num_inference_steps <= 0: + raise ValueError( + f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" + f" {type(num_inference_steps)}." + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + if max_sequence_length is not None and max_sequence_length > 256: + raise ValueError(f"`max_sequence_length` cannot be greater than 256 but is {max_sequence_length}") + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + denoising_end: float | None = None, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + original_size: tuple[int, int] | None = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] | None = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + max_sequence_length: int = 256, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints + that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints + that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.kolors.KolorsPipelineOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.kolors.KolorsPipelineOutput`] or `tuple`: [`~pipelines.kolors.KolorsPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + num_inference_steps, + height, + width, + negative_prompt, + prompt_embeds, + pooled_prompt_embeds, + negative_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + ) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + add_text_embeds = self._prepare_perturbed_attention_guidance( + add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + add_time_ids = self._prepare_perturbed_attention_guidance( + add_time_ids, negative_add_time_ids, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9. Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return KolorsPipelineOutput(images=image) diff --git a/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py b/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py new file mode 100644 index 0000000000000000000000000000000000000000..2e7e8ab9814f890b6e3e2b8c0306da87c8a8d53a --- /dev/null +++ b/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py @@ -0,0 +1,882 @@ +# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +from typing import Callable + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...image_processor import PixArtImageProcessor +from ...models import AutoencoderKL, PixArtTransformer2DModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + BACKENDS_MAPPING, + deprecate, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ..pixart_alpha.pipeline_pixart_alpha import ( + ASPECT_RATIO_256_BIN, + ASPECT_RATIO_512_BIN, + ASPECT_RATIO_1024_BIN, +) +from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForText2Image + + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", + ... torch_dtype=torch.float16, + ... pag_applied_layers=["blocks.14"], + ... enable_pag=True, + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "A small cactus with a happy face in the Sahara desert" + >>> image = pipe(prompt, pag_scale=4.0, guidance_scale=1.0).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin): + r""" + [PAG pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/pag) for text-to-image generation + using PixArt-Sigma. + """ + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: PixArtTransformer2DModel, + scheduler: KarrasDiffusionSchedulers, + pag_applied_layers: str | list[str] = "blocks.1", # 1st transformer block + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->300 + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + **kwargs, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" + string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + """ + + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + + if device is None: + device = self._execution_device + + # See Section 3.1. of the paper. + max_length = max_sequence_length + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because T5 can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt) + prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = uncond_input.attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt) + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: int | None = 1, + height: int | None = None, + width: int | None = None, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + clean_caption: bool = True, + use_resolution_binning: bool = True, + max_sequence_length: int = 300, + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ) -> ImagePipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + max_sequence_length (`int` defaults to 300): Maximum sequence length to use with the `prompt`. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + # 1. Check inputs. Raise error if not correct + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + if use_resolution_binning: + if self.transformer.config.sample_size == 256: + aspect_ratio_bin = ASPECT_RATIO_2048_BIN + elif self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + elif self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_256_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, do_classifier_free_guidance + ) + prompt_attention_mask = self._prepare_perturbed_attention_guidance( + prompt_attention_mask, negative_prompt_attention_mask, do_classifier_free_guidance + ) + elif do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + if self.do_perturbed_attention_guidance: + original_attn_proc = self.transformer.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Prepare micro-conditions. + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + is_npu = latent_model_input.device.type == "npu" + if isinstance(current_timestep, float): + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + else: + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=current_timestep, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, do_classifier_free_guidance, guidance_scale, current_timestep + ) + elif do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + else: + noise_pred = noise_pred + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.transformer.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/pag/pipeline_pag_sana.py b/diffusers/pipelines/pag/pipeline_pag_sana.py new file mode 100644 index 0000000000000000000000000000000000000000..71861f36647717c37e7f18c4a1854e4f68cfe6ed --- /dev/null +++ b/diffusers/pipelines/pag/pipeline_pag_sana.py @@ -0,0 +1,978 @@ +# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +import warnings +from typing import Callable + +import torch +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PixArtImageProcessor +from ...models import AutoencoderDC, SanaTransformer2DModel +from ...models.attention_processor import PAGCFGSanaLinearAttnProcessor2_0, PAGIdentitySanaLinearAttnProcessor2_0 +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + BACKENDS_MAPPING, + deprecate, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import get_device, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ..pixart_alpha.pipeline_pixart_alpha import ( + ASPECT_RATIO_512_BIN, + ASPECT_RATIO_1024_BIN, +) +from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN +from ..sana.pipeline_sana import ASPECT_RATIO_4096_BIN +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import SanaPAGPipeline + + >>> pipe = SanaPAGPipeline.from_pretrained( + ... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", + ... pag_applied_layers=["transformer_blocks.8"], + ... torch_dtype=torch.float32, + ... ) + >>> pipe.to("cuda") + >>> pipe.text_encoder.to(torch.bfloat16) + >>> pipe.transformer = pipe.transformer.to(torch.bfloat16) + + >>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0] + >>> image[0].save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class SanaPAGPipeline(DiffusionPipeline, PAGMixin): + r""" + Pipeline for text-to-image generation using [Sana](https://huggingface.co/papers/2410.10629). This pipeline + supports the use of [Perturbed Attention Guidance + (PAG)](https://huggingface.co/docs/diffusers/main/en/using-diffusers/pag). + """ + + # fmt: off + bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}") + # fmt: on + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + text_encoder: Gemma2PreTrainedModel, + vae: AutoencoderDC, + transformer: SanaTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + pag_applied_layers: str | list[str] = "transformer_blocks.0", + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.encoder_block_out_channels) - 1) + if hasattr(self, "vae") and self.vae is not None + else 8 + ) + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.set_pag_applied_layers( + pag_applied_layers, + pag_attn_processors=(PAGCFGSanaLinearAttnProcessor2_0(), PAGIdentitySanaLinearAttnProcessor2_0()), + ) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the "" string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + # See Section 3.1. of the paper. + max_length = max_sequence_length + select_index = [0] + list(range(-max_length + 1, 0)) + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + + # prepare complex human instruction + if not complex_human_instruction: + max_length_all = max_length + else: + chi_prompt = "\n".join(complex_human_instruction) + prompt = [chi_prompt + p for p in prompt] + num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) + max_length_all = num_chi_prompt_tokens + max_length - 2 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length_all, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0][:, select_index] + prompt_attention_mask = prompt_attention_mask[:, select_index] + + if self.transformer is not None: + dtype = self.transformer.dtype + elif self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = uncond_input.attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: int | None = 1, + height: int = 1024, + width: int = 1024, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + clean_caption: bool = False, + use_resolution_binning: bool = True, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 300, + complex_human_instruction: list[str] = [ + "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:", + "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.", + "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.", + "Here are examples of how to transform or refine prompts:", + "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.", + "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.", + "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:", + "User Prompt: ", + ], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ) -> ImagePipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 20): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 300): Maximum sequence length to use with the `prompt`. + complex_human_instruction (`list[str]`, *optional*): + Instructions for complex human attention: + https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + if use_resolution_binning: + if self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_4096_BIN + elif self.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_2048_BIN + elif self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.transformer.config.sample_size == 16: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + height, + width, + callback_on_step_end_tensor_inputs, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + ) + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + prompt_attention_mask = self._prepare_perturbed_attention_guidance( + prompt_attention_mask, negative_prompt_attention_mask, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + torch.float32, + device, + generator, + latents, + ) + if self.do_perturbed_attention_guidance: + original_attn_proc = self.transformer.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=timestep, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute previous image: x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + else: + latents = latents.to(self.vae.dtype) + torch_accelerator_module = getattr(torch, get_device(), torch.cuda) + oom_error = ( + torch.OutOfMemoryError + if is_torch_version(">=", "2.5.0") + else torch_accelerator_module.OutOfMemoryError + ) + try: + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + except oom_error as e: + warnings.warn( + f"{e}. \n" + f"Try to use VAE tiling for large images. For example: \n" + f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)" + ) + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.transformer.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/pag/pipeline_pag_sd.py b/diffusers/pipelines/pag/pipeline_pag_sd.py new file mode 100644 index 0000000000000000000000000000000000000000..26ea717556c5f8cf5f59c4b2504cbead71279316 --- /dev/null +++ b/diffusers/pipelines/pag/pipeline_pag_sd.py @@ -0,0 +1,1082 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import Any, Callable + +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...configuration_utils import FrozenDict +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForText2Image + + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, enable_pag=True + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt, pag_scale=0.3).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionPAGPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, + PAGMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + pag_applied_layers: str | list[str] = "mid", + ): + super().__init__() + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + guidance_rescale: float = 0.0, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + # to deal with lora scaling and other possible forward hooks + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + None, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": ip_adapter_image_embeds} + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) + else None + ) + + # 6.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True + ) + + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/pag/pipeline_pag_sd_3.py b/diffusers/pipelines/pag/pipeline_pag_sd_3.py new file mode 100644 index 0000000000000000000000000000000000000000..f0fbef29b699bdfe31cc676b2b0742876f5ba0f0 --- /dev/null +++ b/diffusers/pipelines/pag/pipeline_pag_sd_3.py @@ -0,0 +1,998 @@ +# Copyright 2025 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin +from ...models.attention_processor import PAGCFGJointAttnProcessor2_0, PAGJointAttnProcessor2_0 +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import SD3Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForText2Image + + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "stabilityai/stable-diffusion-3-medium-diffusers", + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... pag_applied_layers=["blocks.13"], + ... ) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> image = pipe(prompt, guidance_scale=5.0, pag_scale=0.7).images[0] + >>> image.save("sd3_pag.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, PAGMixin): + r""" + [PAG pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/pag) for text-to-image generation + using Stable Diffusion 3. + + Args: + transformer ([`SD3Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant, + with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size` + as its dimension. + text_encoder_2 ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + text_encoder_3 ([`T5EncoderModel`]): + Frozen text-encoder. Stable Diffusion 3 uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_3 (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] + + def __init__( + self, + transformer: SD3Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5TokenizerFast, + pag_applied_layers: str | list[str] = "blocks.1", # 1st transformer block + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + self.patch_size = ( + self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2 + ) + + self.set_pag_applied_layers( + pag_applied_layers, pag_attn_processors=(PAGCFGJointAttnProcessor2_0(), PAGJointAttnProcessor2_0()) + ) + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 256, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if self.text_encoder_3 is None: + return torch.zeros( + ( + batch_size * num_images_per_prompt, + max_sequence_length, + self.transformer.config.joint_attention_dim, + ), + device=device, + dtype=dtype, + ) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0] + + dtype = self.text_encoder_3.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + clip_skip: int | None = None, + clip_model_index: int = 0, + ): + device = device or self._execution_device + + clip_tokenizers = [self.tokenizer, self.tokenizer_2] + clip_text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str], + prompt_3: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + negative_prompt_3: str | list[str] | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + clip_skip: int | None = None, + max_sequence_length: int = 256, + lora_scale: float | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( + prompt=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=1, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = self._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( + negative_prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=0, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( + negative_prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) + + t5_negative_prompt_embed = self._get_t5_prompt_embeds( + prompt=negative_prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_clip_prompt_embeds = torch.nn.functional.pad( + negative_clip_prompt_embeds, + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ) + + negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 + ) + + if self.text_encoder is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if ( + height % (self.vae_scale_factor * self.patch_size) != 0 + or width % (self.vae_scale_factor * self.patch_size) != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}." + f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_3 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): + raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_3 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + prompt_3: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 28, + sigmas: list[float] | None = None, + guidance_scale: float = 7.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + negative_prompt_3: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 256, + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + will be used instead + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used instead + negative_prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used instead + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale # + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + pooled_prompt_embeds = self._prepare_perturbed_attention_guidance( + pooled_prompt_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.transformer.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.transformer.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) diff --git a/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py b/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..84b727dc061350d22be13f025c2c2d773dcbe2f2 --- /dev/null +++ b/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py @@ -0,0 +1,1062 @@ +# Copyright 2025 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import PIL.Image +import torch +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin +from ...models.attention_processor import PAGCFGJointAttnProcessor2_0, PAGJointAttnProcessor2_0 +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import SD3Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusion3PAGImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> pipe = StableDiffusion3PAGImg2ImgPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-3-medium-diffusers", + ... torch_dtype=torch.float16, + ... pag_applied_layers=["blocks.13"], + ... ) + >>> pipe.to("cuda") + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png" + >>> init_image = load_image(url).convert("RGB") + >>> image = pipe(prompt, image=init_image, guidance_scale=5.0, pag_scale=0.7).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, PAGMixin): + r""" + [PAG pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/pag) for image-to-image generation + using Stable Diffusion 3. + + Args: + transformer ([`SD3Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant, + with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size` + as its dimension. + text_encoder_2 ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + text_encoder_3 ([`T5EncoderModel`]): + Frozen text-encoder. Stable Diffusion 3 uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_3 (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] + + def __init__( + self, + transformer: SD3Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5TokenizerFast, + pag_applied_layers: str | list[str] = "blocks.1", # 1st transformer block + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + self.patch_size = ( + self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2 + ) + + self.set_pag_applied_layers( + pag_applied_layers, pag_attn_processors=(PAGCFGJointAttnProcessor2_0(), PAGJointAttnProcessor2_0()) + ) + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 256, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if self.text_encoder_3 is None: + return torch.zeros( + ( + batch_size * num_images_per_prompt, + max_sequence_length, + self.transformer.config.joint_attention_dim, + ), + device=device, + dtype=dtype, + ) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0] + + dtype = self.text_encoder_3.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + clip_skip: int | None = None, + clip_model_index: int = 0, + ): + device = device or self._execution_device + + clip_tokenizers = [self.tokenizer, self.tokenizer_2] + clip_text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str], + prompt_3: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + negative_prompt_3: str | list[str] | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + clip_skip: int | None = None, + max_sequence_length: int = 256, + lora_scale: float | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( + prompt=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=1, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = self._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( + negative_prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=0, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( + negative_prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) + + t5_negative_prompt_embed = self._get_t5_prompt_embeds( + prompt=negative_prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_clip_prompt_embeds = torch.nn.functional.pad( + negative_clip_prompt_embeds, + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ) + + negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 + ) + + if self.text_encoder is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + prompt_3, + height, + width, + strength, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if ( + height % (self.vae_scale_factor * self.patch_size) != 0 + or width % (self.vae_scale_factor * self.patch_size) != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}." + f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}." + ) + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_3 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): + raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_3 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.prepare_latents + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + if image.shape[1] == self.vae.config.latent_channels: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + init_latents = (init_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.scale_noise(init_latents, timestep, noise) + latents = init_latents.to(device=device, dtype=dtype) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + prompt_3: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + image: PipelineImageInput = None, + strength: float = 0.6, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 7.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + negative_prompt_3: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 256, + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + will be used instead + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used instead + negative_prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used instead + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + strength, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + pooled_prompt_embeds = self._prepare_perturbed_attention_guidance( + pooled_prompt_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 3. Preprocess image + image = self.image_processor.preprocess(image, height=height, width=width) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + ) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.transformer.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + # 6. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.transformer.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) diff --git a/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py b/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py new file mode 100644 index 0000000000000000000000000000000000000000..62d1c912283f7258741c846d3a4df5110768bda2 --- /dev/null +++ b/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py @@ -0,0 +1,878 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput +from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...models.unets.unet_motion_model import MotionAdapter +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..animatediff.pipeline_output import AnimateDiffPipelineOutput +from ..free_init_utils import FreeInitMixin +from ..free_noise_utils import AnimateDiffFreeNoiseMixin +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AnimateDiffPAGPipeline, MotionAdapter, DDIMScheduler + >>> from diffusers.utils import export_to_gif + + >>> model_id = "SG161222/Realistic_Vision_V5.1_noVAE" + >>> motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-2" + >>> motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id) + >>> scheduler = DDIMScheduler.from_pretrained( + ... model_id, subfolder="scheduler", beta_schedule="linear", steps_offset=1, clip_sample=False + ... ) + >>> pipe = AnimateDiffPAGPipeline.from_pretrained( + ... model_id, + ... motion_adapter=motion_adapter, + ... scheduler=scheduler, + ... pag_applied_layers=["mid"], + ... torch_dtype=torch.float16, + ... ).to("cuda") + + >>> video = pipe( + ... prompt="car, futuristic cityscape with neon lights, street, no human", + ... negative_prompt="low quality, bad quality", + ... num_inference_steps=25, + ... guidance_scale=6.0, + ... pag_scale=3.0, + ... generator=torch.Generator().manual_seed(42), + ... ).frames[0] + + >>> export_to_gif(video, "animatediff_pag.gif") + ``` +""" + + +class AnimateDiffPAGPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FreeInitMixin, + AnimateDiffFreeNoiseMixin, + PAGMixin, +): + r""" + Pipeline for text-to-video generation using + [AnimateDiff](https://huggingface.co/docs/diffusers/en/api/pipelines/animatediff) and [Perturbed Attention + Guidance](https://huggingface.co/docs/diffusers/en/using-diffusers/pag). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. + motion_adapter ([`MotionAdapter`]): + A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel | UNetMotionModel, + motion_adapter: MotionAdapter, + scheduler: KarrasDiffusionSchedulers, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + pag_applied_layers: str | list[str] = "mid_block.*attn1", # ["mid"], ["down_blocks.1"] + ): + super().__init__() + if isinstance(unet, UNet2DConditionModel): + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + motion_adapter=motion_adapter, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents + def decode_latents(self, latents, decode_chunk_size: int = 16): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + video = [] + for i in range(0, latents.shape[0], decode_chunk_size): + batch_latents = latents[i : i + decode_chunk_size] + batch_latents = self.vae.decode(batch_latents).sample + video.append(batch_latents) + + video = torch.cat(video) + video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.pia.pipeline_pia.PIAPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://huggingface.co/papers/2310.15169) + if self.free_noise_enabled: + latents = self._prepare_latents_free_noise( + batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + num_frames: int | None = 16, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_videos_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + decode_chunk_size: int = 16, + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": ip_adapter_image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 + for free_init_iter in range(num_free_init_iters): + if self.free_init_enabled: + latents, timesteps = self._apply_free_init( + latents, free_init_iter, num_inference_steps, device, latents.dtype, generator + ) + + self._num_timesteps = len(timesteps) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 8. Denoising loop + with self.progress_bar(total=self._num_timesteps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat( + [latents] * (prompt_embeds.shape[0] // num_frames // latents.shape[0]) + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + ).sample + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # 9. Post processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents, decode_chunk_size) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # 10. Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (video,) + + return AnimateDiffPipelineOutput(frames=video) diff --git a/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py b/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..822483eca995e43da9f3c8b68cb71dfc6ae44c94 --- /dev/null +++ b/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py @@ -0,0 +1,1112 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import PIL.Image +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...configuration_utils import FrozenDict +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForImage2Image + >>> from diffusers.utils import load_image + + >>> pipe = AutoPipelineForImage2Image.from_pretrained( + ... "stable-diffusion-v1-5/stable-diffusion-v1-5", + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... ) + >>> pipe = pipe.to("cuda") + >>> url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png" + + >>> init_image = load_image(url).convert("RGB") + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt, image=init_image, pag_scale=0.3).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionPAGImg2ImgPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, + PAGMixin, +): + r""" + Pipeline for text-guided image-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + pag_applied_layers: str | list[str] = "mid", # ["mid"], ["down.block_1", "up.block_0.attentions_0"] + ): + super().__init__() + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + strength, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + image: PipelineImageInput = None, + strength: float = 0.8, + num_inference_steps: int | None = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float | None = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float | None = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + strength, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. set timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 7.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + self._num_timesteps = len(timesteps) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + if ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py b/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..0f6fbbd9ae1607402a457670d78688ad77544beb --- /dev/null +++ b/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py @@ -0,0 +1,1376 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import Any, Callable + +import PIL.Image +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...configuration_utils import FrozenDict +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForInpainting + + >>> pipe = AutoPipelineForInpainting.from_pretrained( + ... "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, enable_pag=True + ... ) + >>> pipe = pipe.to("cuda") + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + >>> init_image = load_image(img_url).convert("RGB") + >>> mask_image = load_image(mask_url).convert("RGB") + >>> prompt = "A majestic tiger sitting on a bench" + >>> image = pipe( + ... prompt=prompt, + ... image=init_image, + ... mask_image=mask_image, + ... strength=0.8, + ... num_inference_steps=50, + ... guidance_scale=guidance_scale, + ... generator=generator, + ... pag_scale=pag_scale, + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionPAGInpaintPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, + PAGMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + pag_applied_layers: str | list[str] = "mid", + ): + super().__init__() + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.check_inputs + def check_inputs( + self, + prompt, + image, + mask_image, + height, + width, + strength, + callback_steps, + output_type, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 4: + image_latents = image + else: + image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: torch.Tensor = None, + height: int | None = None, + width: int | None = None, + padding_mask_crop: int | None = None, + strength: float = 0.9999, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + guidance_rescale: float = 0.0, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + # to deal with lora scaling and other possible forward hooks + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + mask_image, + height, + width, + strength, + None, + None, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + padding_mask_crop, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. set timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Preprocess mask and image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask_condition = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance, + ) + if self.do_perturbed_attention_guidance: + if self.do_classifier_free_guidance: + mask, _ = mask.chunk(2) + masked_image_latents, _ = masked_image_latents.chunk(2) + mask = self._prepare_perturbed_attention_guidance(mask, mask, self.do_classifier_free_guidance) + masked_image_latents = self._prepare_perturbed_attention_guidance( + masked_image_latents, masked_image_latents, self.do_classifier_free_guidance + ) + + # 8. Check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for stable-diffusion-v1-5/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + # 9 Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + # 9.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": ip_adapter_image_embeds} + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) + else None + ) + + # 9.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 10. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if num_channels_unet == 4: + init_latents_proper = image_latents + if self.do_perturbed_attention_guidance: + init_mask, *_ = mask.chunk(3) if self.do_classifier_free_guidance else mask.chunk(2) + else: + init_mask, *_ = mask.chunk(2) if self.do_classifier_free_guidance else mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + mask = callback_outputs.pop("mask", mask) + masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + condition_kwargs = {} + if isinstance(self.vae, AsymmetricAutoencoderKL): + init_image = init_image.to(device=device, dtype=masked_image_latents.dtype) + init_image_condition = init_image.clone() + init_image = self._encode_vae_image(init_image, generator=generator) + mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype) + condition_kwargs = {"image": init_image_condition, "mask": mask_condition} + image = self.vae.decode( + latents / self.vae.config.scaling_factor, return_dict=False, generator=generator, **condition_kwargs + )[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if padding_mask_crop is not None: + image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/pag/pipeline_pag_sd_xl.py b/diffusers/pipelines/pag/pipeline_pag_sd_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..2987c90626ef62f88b3db276547d25537a589a5e --- /dev/null +++ b/diffusers/pipelines/pag/pipeline_pag_sd_xl.py @@ -0,0 +1,1344 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_invisible_watermark_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from .pag_utils import PAGMixin + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForText2Image + + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt, pag_scale=0.3).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLPAGPipeline( + DiffusionPipeline, + StableDiffusionMixin, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + PAGMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: bool | None = None, + pag_applied_layers: str | list[str] = "mid", # ["mid"],["down.block_1"],["up.block_0.attentions_0"] + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + denoising_end: float | None = None, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + guidance_rescale: float = 0.0, + original_size: tuple[int, int] | None = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] | None = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + None, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + add_text_embeds = self._prepare_perturbed_attention_guidance( + add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + add_time_ids = self._prepare_perturbed_attention_guidance( + add_time_ids, negative_add_time_ids, self.do_classifier_free_guidance + ) + + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9. Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True + ) + + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..433b9edc69b7c895c2aa645fc91480cd654957f5 --- /dev/null +++ b/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py @@ -0,0 +1,1539 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_invisible_watermark_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import empty_device_cache, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from .pag_utils import PAGMixin + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForImage2Image + >>> from diffusers.utils import load_image + + >>> pipe = AutoPipelineForImage2Image.from_pretrained( + ... "stabilityai/stable-diffusion-xl-refiner-1.0", + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... ) + >>> pipe = pipe.to("cuda") + >>> url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png" + + >>> init_image = load_image(url).convert("RGB") + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt, image=init_image, pag_scale=0.3).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLPAGImg2ImgPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + IPAdapterMixin, + PAGMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): + Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the + config of `stabilityai/stable-diffusion-xl-refiner-1-0`. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "add_neg_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: bool | None = None, + pag_applied_layers: str | list[str] = "mid", # ["mid"], ["down.block_1", "up.block_0.attentions_0"] + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + strength, + num_inference_steps, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if num_inference_steps is None: + raise ValueError("`num_inference_steps` cannot be None.") + elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0: + raise ValueError( + f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" + f" {type(num_inference_steps)}." + ) + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(self.scheduler.timesteps) - num_inference_steps + timesteps = self.scheduler.timesteps[t_start:] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents + def prepare_latents( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + + # Offload text encoder if `enable_model_cpu_offload` was enabled + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.text_encoder_2.to("cpu") + empty_device_cache() + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def denoising_start(self): + return self._denoising_start + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + image: PipelineImageInput = None, + strength: float = 0.3, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + denoising_start: float | None = None, + denoising_end: float | None = None, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + guidance_rescale: float = 0.0, + original_size: tuple[int, int] = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`torch.Tensor` or `PIL.Image.Image` or `np.ndarray` or `list[torch.Tensor]` or `list[PIL.Image.Image]` or `list[np.ndarray]`): + The image(s) to modify with the pipeline. + strength (`float`, *optional*, defaults to 0.3): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of + `denoising_start` being declared as an integer, the value of `strength` will be ignored. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_start (`float`, *optional*): + When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be + bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and + it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, + strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline + is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image + Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality). + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be + denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the + final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline + forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refine Image + Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality). + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + strength, + num_inference_steps, + None, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._denoising_start = denoising_start + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. Prepare timesteps + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None, + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + add_noise = True if self.denoising_start is None else False + + # 6. Prepare latent variables + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + add_noise, + ) + # 7. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + height, width = latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 8. Prepare added time ids & embeddings + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + add_text_embeds = self._prepare_perturbed_attention_guidance( + add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + add_time_ids = self._prepare_perturbed_attention_guidance( + add_time_ids, add_neg_time_ids, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + # 9. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 9.1 Apply denoising_end + if ( + self.denoising_end is not None + and self.denoising_start is not None + and denoising_value_valid(self.denoising_end) + and denoising_value_valid(self.denoising_start) + and self.denoising_start >= self.denoising_end + ): + raise ValueError( + f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: " + + f" {self.denoising_end} when using type float." + ) + elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..9caf50e5e333ec4ce9295e5a597a094c0d749087 --- /dev/null +++ b/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py @@ -0,0 +1,1770 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_invisible_watermark_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from .pag_utils import PAGMixin + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForInpainting + >>> from diffusers.utils import load_image + + >>> pipe = AutoPipelineForInpainting.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", + ... torch_dtype=torch.float16, + ... variant="fp16", + ... enable_pag=True, + ... ) + >>> pipe.to("cuda") + + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + + >>> init_image = load_image(img_url).convert("RGB") + >>> mask_image = load_image(mask_url).convert("RGB") + + >>> prompt = "A majestic tiger sitting on a bench" + >>> image = pipe( + ... prompt=prompt, + ... image=init_image, + ... mask_image=mask_image, + ... num_inference_steps=50, + ... strength=0.80, + ... pag_scale=0.3, + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLPAGInpaintPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + FromSingleFileMixin, + IPAdapterMixin, + PAGMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): + Whether the `unet` requires a aesthetic_score condition to be passed during inference. Also see the config + of `stabilityai/stable-diffusion-xl-refiner-1-0`. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "add_neg_time_ids", + "mask", + "masked_image_latents", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: bool | None = None, + pag_applied_layers: str | list[str] = "mid", # ["mid"], ["down.block_1", "up.block_0.attentions_0"] + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + image, + mask_image, + height, + width, + strength, + callback_steps, + output_type, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + add_noise=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if image.shape[1] == 4: + image_latents = image.to(device=device, dtype=dtype) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + elif return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None and add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + elif add_noise: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + else: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = image_latents.to(device) + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + dtype = image.dtype + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + image_latents = image_latents.to(dtype) + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + + if masked_image is not None and masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if masked_image is not None: + if masked_image_latents is None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(self.scheduler.timesteps) - num_inference_steps + timesteps = self.scheduler.timesteps[t_start:] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def denoising_start(self): + return self._denoising_start + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: torch.Tensor = None, + height: int | None = None, + width: int | None = None, + padding_mask_crop: int | None = None, + strength: float = 0.9999, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + denoising_start: float | None = None, + denoising_end: float | None = None, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + guidance_rescale: float = 0.0, + original_size: tuple[int, int] = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 0.9999): + Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be + between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the + `strength`. The number of denoising steps depends on the amount of noise initially added. When + `strength` is 1, added noise will be maximum and the denoising process will run for the full number of + iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked + portion of the reference `image`. Note that in the case of `denoising_start` being declared as an + integer, the value of `strength` will be ignored. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_start (`float`, *optional*): + When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be + bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and + it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, + strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline + is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be + denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the + final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline + forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. `tuple. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs + self.check_inputs( + prompt, + prompt_2, + image, + mask_image, + height, + width, + strength, + None, + output_type, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + padding_mask_crop, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._denoising_start = denoising_start + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. set timesteps + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None, + ) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Preprocess mask and image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + mask = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is not None: + masked_image = masked_image_latents + elif init_image.shape[1] == 4: + # if images are in latent space, we can't mask it + masked_image = None + else: + masked_image = init_image * (mask < 0.5) + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + add_noise = True if self.denoising_start is None else False + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + add_noise=add_noise, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance, + ) + if self.do_perturbed_attention_guidance: + if self.do_classifier_free_guidance: + mask, _ = mask.chunk(2) + masked_image_latents, _ = masked_image_latents.chunk(2) + mask = self._prepare_perturbed_attention_guidance(mask, mask, self.do_classifier_free_guidance) + masked_image_latents = self._prepare_perturbed_attention_guidance( + masked_image_latents, masked_image_latents, self.do_classifier_free_guidance + ) + + # 8. Check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for stable-diffusion-v1-5/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + # 8.1 Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + height, width = latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 10. Prepare added time ids & embeddings + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + add_text_embeds = self._prepare_perturbed_attention_guidance( + add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + add_time_ids = self._prepare_perturbed_attention_guidance( + add_time_ids, add_neg_time_ids, self.do_classifier_free_guidance + ) + + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + # 11. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + if ( + self.denoising_end is not None + and self.denoising_start is not None + and denoising_value_valid(self.denoising_end) + and denoising_value_valid(self.denoising_start) + and self.denoising_start >= self.denoising_end + ): + raise ValueError( + f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: " + + f" {self.denoising_end} when using type float." + ) + elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 11.1 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if num_channels_unet == 4: + init_latents_proper = image_latents + if self.do_perturbed_attention_guidance: + init_mask, *_ = mask.chunk(3) if self.do_classifier_free_guidance else mask.chunk(2) + else: + init_mask, *_ = mask.chunk(2) if self.do_classifier_free_guidance else mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) + mask = callback_outputs.pop("mask", mask) + masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + return StableDiffusionXLPipelineOutput(images=latents) + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + if padding_mask_crop is not None: + image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/diffusers/pipelines/paint_by_example/__init__.py b/diffusers/pipelines/paint_by_example/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d2906b540c6eeb6f463340c8f856611e6c5dbd2f --- /dev/null +++ b/diffusers/pipelines/paint_by_example/__init__.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import numpy as np +import PIL +from PIL import Image + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["image_encoder"] = ["PaintByExampleImageEncoder"] + _import_structure["pipeline_paint_by_example"] = ["PaintByExamplePipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .image_encoder import PaintByExampleImageEncoder + from .pipeline_paint_by_example import PaintByExamplePipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/paint_by_example/image_encoder.py b/diffusers/pipelines/paint_by_example/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..74c575ed8653d110ca300e9198a5b13070659adf --- /dev/null +++ b/diffusers/pipelines/paint_by_example/image_encoder.py @@ -0,0 +1,67 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch import nn +from transformers import CLIPPreTrainedModel, CLIPVisionModel + +from ...models.attention import BasicTransformerBlock +from ...utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class PaintByExampleImageEncoder(CLIPPreTrainedModel): + def __init__(self, config, proj_size=None): + super().__init__(config) + self.proj_size = proj_size or getattr(config, "projection_dim", 768) + + self.model = CLIPVisionModel(config) + self.mapper = PaintByExampleMapper(config) + self.final_layer_norm = nn.LayerNorm(config.hidden_size) + self.proj_out = nn.Linear(config.hidden_size, self.proj_size) + + # uncondition for scaling + self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size))) + + def forward(self, pixel_values, return_uncond_vector=False): + clip_output = self.model(pixel_values=pixel_values) + latent_states = clip_output.pooler_output + latent_states = self.mapper(latent_states[:, None]) + latent_states = self.final_layer_norm(latent_states) + latent_states = self.proj_out(latent_states) + if return_uncond_vector: + return latent_states, self.uncond_vector + + return latent_states + + +class PaintByExampleMapper(nn.Module): + def __init__(self, config): + super().__init__() + num_layers = (config.num_hidden_layers + 1) // 5 + hid_size = config.hidden_size + num_heads = 1 + self.blocks = nn.ModuleList( + [ + BasicTransformerBlock(hid_size, num_heads, hid_size, activation_fn="gelu", attention_bias=True) + for _ in range(num_layers) + ] + ) + + def forward(self, hidden_states): + for block in self.blocks: + hidden_states = block(hidden_states) + + return hidden_states diff --git a/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py new file mode 100644 index 0000000000000000000000000000000000000000..aa7dbaa720e570edbed4acf497cb01c9a9339935 --- /dev/null +++ b/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py @@ -0,0 +1,633 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor + +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import deprecate, is_torch_xla_available, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .image_encoder import PaintByExampleImageEncoder + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def prepare_mask_and_masked_image(image, mask): + """ + Prepares a pair (image, mask) to be consumed by the Paint by Example pipeline. This means that those inputs will be + converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the + ``image`` and ``1`` for the ``mask``. + + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (np.array | PIL.Image | torch.Tensor): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + + # Batch single image + if image.ndim == 3: + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Batched mask + if mask.shape[0] == image.shape[0]: + mask = mask.unsqueeze(1) + else: + mask = mask.unsqueeze(0) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + assert mask.shape[1] == 1, "Mask image must have a single channel" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # paint-by-example inverses the mask + mask = 1 - mask + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + if isinstance(image, PIL.Image.Image): + image = [image] + + image = np.concatenate([np.array(i.convert("RGB"))[None, :] for i in image], axis=0) + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + # preprocess mask + if isinstance(mask, PIL.Image.Image): + mask = [mask] + + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + + # paint-by-example inverses the mask + mask = 1 - mask + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = image * mask + + return mask, masked_image + + +class PaintByExamplePipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin): + _last_supported_version = "0.33.1" + r""" + > [!WARNING] > 🧪 This is an experimental feature! + + Pipeline for image-guided image inpainting using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + image_encoder ([`PaintByExampleImageEncoder`]): + Encodes the example input image. The `unet` is conditioned on the example image instead of a text prompt. + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + + """ + + # TODO: feature_extractor is required to encode initial images (if they are in PIL format), + # we should give a descriptive message if the pipeline doesn't have one. + + model_cpu_offload_seq = "unet->vae" + _exclude_from_cpu_offload = ["image_encoder"] + _optional_components = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + image_encoder: PaintByExampleImageEncoder, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler | PNDMScheduler | LMSDiscreteScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = False, + ): + super().__init__() + + self.register_modules( + vae=vae, + image_encoder=image_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs + def check_inputs(self, image, height, width, callback_steps): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `list[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(images=image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings, negative_prompt_embeds = self.image_encoder(image, return_uncond_vector=True) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + negative_prompt_embeds = negative_prompt_embeds.repeat(1, image_embeddings.shape[0], 1) + negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, 1, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) + + return image_embeddings + + @torch.no_grad() + def __call__( + self, + example_image: torch.Tensor | PIL.Image.Image, + image: torch.Tensor | PIL.Image.Image, + mask_image: torch.Tensor | PIL.Image.Image, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + ): + r""" + The call function to the pipeline for generation. + + Args: + example_image (`torch.Tensor` or `PIL.Image.Image` or `list[PIL.Image.Image]`): + An example image to guide image generation. + image (`torch.Tensor` or `PIL.Image.Image` or `list[PIL.Image.Image]`): + `Image` or tensor representing an image batch to be inpainted (parts of the image are masked out with + `mask_image` and repainted according to `prompt`). + mask_image (`torch.Tensor` or `PIL.Image.Image` or `list[PIL.Image.Image]`): + `Image` or tensor representing an image batch to mask `image`. White pixels in the mask are repainted, + while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a single channel + (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the + expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Example: + + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + >>> from diffusers import PaintByExamplePipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = ( + ... "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/image/example_1.png" + ... ) + >>> mask_url = ( + ... "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/mask/example_1.png" + ... ) + >>> example_url = "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/reference/example_1.jpg" + + >>> init_image = download_image(img_url).resize((512, 512)) + >>> mask_image = download_image(mask_url).resize((512, 512)) + >>> example_image = download_image(example_url).resize((512, 512)) + + >>> pipe = PaintByExamplePipeline.from_pretrained( + ... "Fantasy-Studio/Paint-by-Example", + ... torch_dtype=torch.float16, + ... ) + >>> pipe = pipe.to("cuda") + + >>> image = pipe(image=init_image, mask_image=mask_image, example_image=example_image).images[0] + >>> image + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 1. Define call parameters + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 2. Preprocess mask and image + mask, masked_image = prepare_mask_and_masked_image(image, mask_image) + height, width = masked_image.shape[-2:] + + # 3. Check inputs + self.check_inputs(example_image, height, width, callback_steps) + + # 4. Encode input image + image_embeddings = self._encode_image( + example_image, device, num_images_per_prompt, do_classifier_free_guidance + ) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + image_embeddings.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 8. Check that sizes of mask, masked image and latents match + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 10. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latent_model_input, masked_image_latents, mask], dim=1) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + self.maybe_free_model_hooks() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/pixart_alpha/__init__.py b/diffusers/pipelines/pixart_alpha/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7cb870fc8589aeff366429c182736b2ddd6ce215 --- /dev/null +++ b/diffusers/pipelines/pixart_alpha/__init__.py @@ -0,0 +1,55 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_pixart_alpha"] = ["PixArtAlphaPipeline"] + _import_structure["pipeline_pixart_sigma"] = ["PixArtSigmaPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_pixart_alpha import ( + ASPECT_RATIO_256_BIN, + ASPECT_RATIO_512_BIN, + ASPECT_RATIO_1024_BIN, + PixArtAlphaPipeline, + ) + from .pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN, PixArtSigmaPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py new file mode 100644 index 0000000000000000000000000000000000000000..604e51d885834e507b852dfcca588260b2907178 --- /dev/null +++ b/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -0,0 +1,980 @@ +# Copyright 2025 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +from typing import Callable + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...image_processor import PixArtImageProcessor +from ...models import AutoencoderKL, PixArtTransformer2DModel +from ...schedulers import DPMSolverMultistepScheduler +from ...utils import ( + BACKENDS_MAPPING, + deprecate, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import PixArtAlphaPipeline + + >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too. + >>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16) + >>> # Enable memory optimizations. + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A small cactus with a happy face in the Sahara desert." + >>> image = pipe(prompt).images[0] + ``` +""" + +ASPECT_RATIO_1024_BIN = { + "0.25": [512.0, 2048.0], + "0.28": [512.0, 1856.0], + "0.32": [576.0, 1792.0], + "0.33": [576.0, 1728.0], + "0.35": [576.0, 1664.0], + "0.4": [640.0, 1600.0], + "0.42": [640.0, 1536.0], + "0.48": [704.0, 1472.0], + "0.5": [704.0, 1408.0], + "0.52": [704.0, 1344.0], + "0.57": [768.0, 1344.0], + "0.6": [768.0, 1280.0], + "0.68": [832.0, 1216.0], + "0.72": [832.0, 1152.0], + "0.78": [896.0, 1152.0], + "0.82": [896.0, 1088.0], + "0.88": [960.0, 1088.0], + "0.94": [960.0, 1024.0], + "1.0": [1024.0, 1024.0], + "1.07": [1024.0, 960.0], + "1.13": [1088.0, 960.0], + "1.21": [1088.0, 896.0], + "1.29": [1152.0, 896.0], + "1.38": [1152.0, 832.0], + "1.46": [1216.0, 832.0], + "1.67": [1280.0, 768.0], + "1.75": [1344.0, 768.0], + "2.0": [1408.0, 704.0], + "2.09": [1472.0, 704.0], + "2.4": [1536.0, 640.0], + "2.5": [1600.0, 640.0], + "3.0": [1728.0, 576.0], + "4.0": [2048.0, 512.0], +} + +ASPECT_RATIO_512_BIN = { + "0.25": [256.0, 1024.0], + "0.28": [256.0, 928.0], + "0.32": [288.0, 896.0], + "0.33": [288.0, 864.0], + "0.35": [288.0, 832.0], + "0.4": [320.0, 800.0], + "0.42": [320.0, 768.0], + "0.48": [352.0, 736.0], + "0.5": [352.0, 704.0], + "0.52": [352.0, 672.0], + "0.57": [384.0, 672.0], + "0.6": [384.0, 640.0], + "0.68": [416.0, 608.0], + "0.72": [416.0, 576.0], + "0.78": [448.0, 576.0], + "0.82": [448.0, 544.0], + "0.88": [480.0, 544.0], + "0.94": [480.0, 512.0], + "1.0": [512.0, 512.0], + "1.07": [512.0, 480.0], + "1.13": [544.0, 480.0], + "1.21": [544.0, 448.0], + "1.29": [576.0, 448.0], + "1.38": [576.0, 416.0], + "1.46": [608.0, 416.0], + "1.67": [640.0, 384.0], + "1.75": [672.0, 384.0], + "2.0": [704.0, 352.0], + "2.09": [736.0, 352.0], + "2.4": [768.0, 320.0], + "2.5": [800.0, 320.0], + "3.0": [864.0, 288.0], + "4.0": [1024.0, 256.0], +} + +ASPECT_RATIO_256_BIN = { + "0.25": [128.0, 512.0], + "0.28": [128.0, 464.0], + "0.32": [144.0, 448.0], + "0.33": [144.0, 432.0], + "0.35": [144.0, 416.0], + "0.4": [160.0, 400.0], + "0.42": [160.0, 384.0], + "0.48": [176.0, 368.0], + "0.5": [176.0, 352.0], + "0.52": [176.0, 336.0], + "0.57": [192.0, 336.0], + "0.6": [192.0, 320.0], + "0.68": [208.0, 304.0], + "0.72": [208.0, 288.0], + "0.78": [224.0, 288.0], + "0.82": [224.0, 272.0], + "0.88": [240.0, 272.0], + "0.94": [240.0, 256.0], + "1.0": [256.0, 256.0], + "1.07": [256.0, 240.0], + "1.13": [272.0, 240.0], + "1.21": [272.0, 224.0], + "1.29": [288.0, 224.0], + "1.38": [288.0, 208.0], + "1.46": [304.0, 208.0], + "1.67": [320.0, 192.0], + "1.75": [336.0, 192.0], + "2.0": [352.0, 176.0], + "2.09": [368.0, 176.0], + "2.4": [384.0, 160.0], + "2.5": [400.0, 160.0], + "3.0": [432.0, 144.0], + "4.0": [512.0, 128.0], +} + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class PixArtAlphaPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using PixArt-Alpha. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. PixArt-Alpha uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`PixArtTransformer2DModel`]): + A text conditioned `PixArtTransformer2DModel` to denoise the encoded image latents. Initially published as + [`Transformer2DModel`](https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS/blob/main/transformer/config.json#L2) + in the config, but the mismatch can be ignored. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: PixArtTransformer2DModel, + scheduler: DPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + clean_caption: bool = False, + max_sequence_length: int = 120, + **kwargs, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" + string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt. + """ + + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + + if device is None: + device = self._execution_device + + # See Section 3.1. of the paper. + max_length = max_sequence_length + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because T5 can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt) + prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = uncond_input.attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt) + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: int | None = 1, + height: int | None = None, + width: int | None = None, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + clean_caption: bool = True, + use_resolution_binning: bool = True, + max_sequence_length: int = 120, + **kwargs, + ) -> ImagePipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + max_sequence_length (`int` defaults to 120): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + # 1. Check inputs. Raise error if not correct + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + if use_resolution_binning: + if self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + elif self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_256_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Prepare micro-conditions. + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + if self.transformer.config.sample_size == 128: + resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1) + aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1) + resolution = resolution.to(dtype=prompt_embeds.dtype, device=device) + aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device) + + if do_classifier_free_guidance: + resolution = torch.cat([resolution, resolution], dim=0) + aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0) + + added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + is_npu = latent_model_input.device.type == "npu" + if isinstance(current_timestep, float): + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + else: + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=current_timestep, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + else: + noise_pred = noise_pred + + # compute previous image: x_t -> x_t-1 + if num_inference_steps == 1: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[1] + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py new file mode 100644 index 0000000000000000000000000000000000000000..286695aa8eb98d49ef40f6dfc8349496fac7df67 --- /dev/null +++ b/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py @@ -0,0 +1,910 @@ +# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +from typing import Callable + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...image_processor import PixArtImageProcessor +from ...models import AutoencoderKL, PixArtTransformer2DModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + BACKENDS_MAPPING, + deprecate, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .pipeline_pixart_alpha import ( + ASPECT_RATIO_256_BIN, + ASPECT_RATIO_512_BIN, + ASPECT_RATIO_1024_BIN, +) + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +ASPECT_RATIO_2048_BIN = { + "0.25": [1024.0, 4096.0], + "0.26": [1024.0, 3968.0], + "0.27": [1024.0, 3840.0], + "0.28": [1024.0, 3712.0], + "0.32": [1152.0, 3584.0], + "0.33": [1152.0, 3456.0], + "0.35": [1152.0, 3328.0], + "0.4": [1280.0, 3200.0], + "0.42": [1280.0, 3072.0], + "0.48": [1408.0, 2944.0], + "0.5": [1408.0, 2816.0], + "0.52": [1408.0, 2688.0], + "0.57": [1536.0, 2688.0], + "0.6": [1536.0, 2560.0], + "0.68": [1664.0, 2432.0], + "0.72": [1664.0, 2304.0], + "0.78": [1792.0, 2304.0], + "0.82": [1792.0, 2176.0], + "0.88": [1920.0, 2176.0], + "0.94": [1920.0, 2048.0], + "1.0": [2048.0, 2048.0], + "1.07": [2048.0, 1920.0], + "1.13": [2176.0, 1920.0], + "1.21": [2176.0, 1792.0], + "1.29": [2304.0, 1792.0], + "1.38": [2304.0, 1664.0], + "1.46": [2432.0, 1664.0], + "1.67": [2560.0, 1536.0], + "1.75": [2688.0, 1536.0], + "2.0": [2816.0, 1408.0], + "2.09": [2944.0, 1408.0], + "2.4": [3072.0, 1280.0], + "2.5": [3200.0, 1280.0], + "2.89": [3328.0, 1152.0], + "3.0": [3456.0, 1152.0], + "3.11": [3584.0, 1152.0], + "3.62": [3712.0, 1024.0], + "3.75": [3840.0, 1024.0], + "3.88": [3968.0, 1024.0], + "4.0": [4096.0, 1024.0], +} + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import PixArtSigmaPipeline + + >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-Sigma-XL-2-512-MS" too. + >>> pipe = PixArtSigmaPipeline.from_pretrained( + ... "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16 + ... ) + >>> # Enable memory optimizations. + >>> # pipe.enable_model_cpu_offload() + + >>> prompt = "A small cactus with a happy face in the Sahara desert." + >>> image = pipe(prompt).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class PixArtSigmaPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using PixArt-Sigma. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. PixArt-Alpha uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`PixArtTransformer2DModel`]): + A text conditioned `PixArtTransformer2DModel` to denoise the encoded image latents. Initially published as + [`Transformer2DModel`](https://huggingface.co/PixArt-alpha/PixArt-Sigma-XL-2-1024-MS/blob/main/transformer/config.json#L2) + in the config, but the mismatch can be ignored. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: PixArtTransformer2DModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->300 + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + **kwargs, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" + string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + """ + + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + + if device is None: + device = self._execution_device + + # See Section 3.1. of the paper. + max_length = max_sequence_length + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because T5 can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt) + prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = uncond_input.attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt) + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: int | None = 1, + height: int | None = None, + width: int | None = None, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + clean_caption: bool = True, + use_resolution_binning: bool = True, + max_sequence_length: int = 300, + **kwargs, + ) -> ImagePipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + max_sequence_length (`int` defaults to 300): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + # 1. Check inputs. Raise error if not correct + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + if use_resolution_binning: + if self.transformer.config.sample_size == 256: + aspect_ratio_bin = ASPECT_RATIO_2048_BIN + elif self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + elif self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_256_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Prepare micro-conditions. + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + is_npu = latent_model_input.device.type == "npu" + if isinstance(current_timestep, float): + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + else: + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=current_timestep, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + else: + noise_pred = noise_pred + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents.to(self.vae.dtype) / self.vae.config.scaling_factor, return_dict=False)[0] + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/sana/__init__.py b/diffusers/pipelines/sana/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..91684f35f153516de83cd65ebed28164970b7548 --- /dev/null +++ b/diffusers/pipelines/sana/__init__.py @@ -0,0 +1,53 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_sana"] = ["SanaPipeline"] + _import_structure["pipeline_sana_controlnet"] = ["SanaControlNetPipeline"] + _import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"] + _import_structure["pipeline_sana_sprint_img2img"] = ["SanaSprintImg2ImgPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_sana import SanaPipeline + from .pipeline_sana_controlnet import SanaControlNetPipeline + from .pipeline_sana_sprint import SanaSprintPipeline + from .pipeline_sana_sprint_img2img import SanaSprintImg2ImgPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/sana/pipeline_output.py b/diffusers/pipelines/sana/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..b9d095906e06ba975233a2a1577d1dac651c813d --- /dev/null +++ b/diffusers/pipelines/sana/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class SanaPipelineOutput(BaseOutput): + """ + Output class for Sana pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image] | np.ndarray diff --git a/diffusers/pipelines/sana/pipeline_sana.py b/diffusers/pipelines/sana/pipeline_sana.py new file mode 100644 index 0000000000000000000000000000000000000000..17e0be9ba7aa23a9c810e49cce88b3796a1c7ae1 --- /dev/null +++ b/diffusers/pipelines/sana/pipeline_sana.py @@ -0,0 +1,1040 @@ +# Copyright 2025 SANA Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +import warnings +from typing import Any, Callable + +import torch +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PixArtImageProcessor +from ...loaders import SanaLoraLoaderMixin +from ...models import AutoencoderDC, SanaTransformer2DModel +from ...schedulers import DPMSolverMultistepScheduler +from ...utils import ( + BACKENDS_MAPPING, + USE_PEFT_BACKEND, + deprecate, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import get_device, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..pixart_alpha.pipeline_pixart_alpha import ( + ASPECT_RATIO_512_BIN, + ASPECT_RATIO_1024_BIN, +) +from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN +from .pipeline_output import SanaPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +ASPECT_RATIO_4096_BIN = { + "0.25": [2048.0, 8192.0], + "0.26": [2048.0, 7936.0], + "0.27": [2048.0, 7680.0], + "0.28": [2048.0, 7424.0], + "0.32": [2304.0, 7168.0], + "0.33": [2304.0, 6912.0], + "0.35": [2304.0, 6656.0], + "0.4": [2560.0, 6400.0], + "0.42": [2560.0, 6144.0], + "0.48": [2816.0, 5888.0], + "0.5": [2816.0, 5632.0], + "0.52": [2816.0, 5376.0], + "0.57": [3072.0, 5376.0], + "0.6": [3072.0, 5120.0], + "0.68": [3328.0, 4864.0], + "0.72": [3328.0, 4608.0], + "0.78": [3584.0, 4608.0], + "0.82": [3584.0, 4352.0], + "0.88": [3840.0, 4352.0], + "0.94": [3840.0, 4096.0], + "1.0": [4096.0, 4096.0], + "1.07": [4096.0, 3840.0], + "1.13": [4352.0, 3840.0], + "1.21": [4352.0, 3584.0], + "1.29": [4608.0, 3584.0], + "1.38": [4608.0, 3328.0], + "1.46": [4864.0, 3328.0], + "1.67": [5120.0, 3072.0], + "1.75": [5376.0, 3072.0], + "2.0": [5632.0, 2816.0], + "2.09": [5888.0, 2816.0], + "2.4": [6144.0, 2560.0], + "2.5": [6400.0, 2560.0], + "2.89": [6656.0, 2304.0], + "3.0": [6912.0, 2304.0], + "3.11": [7168.0, 2304.0], + "3.62": [7424.0, 2048.0], + "3.75": [7680.0, 2048.0], + "3.88": [7936.0, 2048.0], + "4.0": [8192.0, 2048.0], +} + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import SanaPipeline + + >>> pipe = SanaPipeline.from_pretrained( + ... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.float32 + ... ) + >>> pipe.to("cuda") + >>> pipe.text_encoder.to(torch.bfloat16) + >>> pipe.transformer = pipe.transformer.to(torch.bfloat16) + + >>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0] + >>> image[0].save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin): + r""" + Pipeline for text-to-image generation using [Sana](https://huggingface.co/papers/2410.10629). + """ + + # fmt: off + bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}") + # fmt: on + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + text_encoder: Gemma2PreTrainedModel, + vae: AutoencoderDC, + transformer: SanaTransformer2DModel, + scheduler: DPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.encoder_block_out_channels) - 1) + if hasattr(self, "vae") and self.vae is not None + else 32 + ) + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + def _get_gemma_prompt_embeds( + self, + prompt: str | list[str], + device: torch.device, + dtype: torch.dtype, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + + # prepare complex human instruction + if not complex_human_instruction: + max_length_all = max_sequence_length + else: + chi_prompt = "\n".join(complex_human_instruction) + prompt = [chi_prompt + p for p in prompt] + num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) + max_length_all = num_chi_prompt_tokens + max_sequence_length - 2 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length_all, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device) + + return prompt_embeds, prompt_attention_mask + + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + lora_scale: float | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the "" string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + + if device is None: + device = self._execution_device + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + # See Section 3.1. of the paper. + max_length = max_sequence_length + select_index = [0] + list(range(-max_length + 1, 0)) + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + ) + + prompt_embeds = prompt_embeds[:, select_index] + prompt_attention_mask = prompt_attention_mask[:, select_index] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=False, + ) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + if self.text_encoder is not None: + if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: int | None = 1, + height: int = 1024, + width: int = 1024, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + clean_caption: bool = False, + use_resolution_binning: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 300, + complex_human_instruction: list[str] = [ + "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:", + "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.", + "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.", + "Here are examples of how to transform or refine prompts:", + "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.", + "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.", + "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:", + "User Prompt: ", + ], + ) -> SanaPipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 20): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + attention_kwargs: + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to `300`): + Maximum sequence length to use with the `prompt`. + complex_human_instruction (`list[str]`, *optional*): + Instructions for complex human attention: + https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55. + + Examples: + + Returns: + [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + if use_resolution_binning: + if self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_4096_BIN + elif self.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_2048_BIN + elif self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.transformer.config.sample_size == 16: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + height, + width, + callback_on_step_end_tensor_inputs, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + lora_scale=lora_scale, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + transformer_dtype = self.transformer.dtype + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + timestep = timestep * self.transformer.config.timestep_scale + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input.to(dtype=transformer_dtype), + encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype), + encoder_attention_mask=prompt_attention_mask, + timestep=timestep, + return_dict=False, + attention_kwargs=self.attention_kwargs, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + else: + latents = latents.to(self.vae.dtype) + torch_accelerator_module = getattr(torch, get_device(), torch.cuda) + oom_error = ( + torch.OutOfMemoryError + if is_torch_version(">=", "2.5.0") + else torch_accelerator_module.OutOfMemoryError + ) + try: + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + except oom_error as e: + warnings.warn( + f"{e}. \n" + f"Try to use VAE tiling for large images. For example: \n" + f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)" + ) + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return SanaPipelineOutput(images=image) diff --git a/diffusers/pipelines/sana/pipeline_sana_controlnet.py b/diffusers/pipelines/sana/pipeline_sana_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..d976c7035d9d136b28b861a2803c7d723c82659b --- /dev/null +++ b/diffusers/pipelines/sana/pipeline_sana_controlnet.py @@ -0,0 +1,1135 @@ +# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +import warnings +from typing import Any, Callable + +import torch +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, PixArtImageProcessor +from ...loaders import SanaLoraLoaderMixin +from ...models import AutoencoderDC, SanaControlNetModel, SanaTransformer2DModel +from ...schedulers import DPMSolverMultistepScheduler +from ...utils import ( + BACKENDS_MAPPING, + USE_PEFT_BACKEND, + deprecate, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import get_device, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..pixart_alpha.pipeline_pixart_alpha import ( + ASPECT_RATIO_512_BIN, + ASPECT_RATIO_1024_BIN, +) +from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN +from .pipeline_output import SanaPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +ASPECT_RATIO_4096_BIN = { + "0.25": [2048.0, 8192.0], + "0.26": [2048.0, 7936.0], + "0.27": [2048.0, 7680.0], + "0.28": [2048.0, 7424.0], + "0.32": [2304.0, 7168.0], + "0.33": [2304.0, 6912.0], + "0.35": [2304.0, 6656.0], + "0.4": [2560.0, 6400.0], + "0.42": [2560.0, 6144.0], + "0.48": [2816.0, 5888.0], + "0.5": [2816.0, 5632.0], + "0.52": [2816.0, 5376.0], + "0.57": [3072.0, 5376.0], + "0.6": [3072.0, 5120.0], + "0.68": [3328.0, 4864.0], + "0.72": [3328.0, 4608.0], + "0.78": [3584.0, 4608.0], + "0.82": [3584.0, 4352.0], + "0.88": [3840.0, 4352.0], + "0.94": [3840.0, 4096.0], + "1.0": [4096.0, 4096.0], + "1.07": [4096.0, 3840.0], + "1.13": [4352.0, 3840.0], + "1.21": [4352.0, 3584.0], + "1.29": [4608.0, 3584.0], + "1.38": [4608.0, 3328.0], + "1.46": [4864.0, 3328.0], + "1.67": [5120.0, 3072.0], + "1.75": [5376.0, 3072.0], + "2.0": [5632.0, 2816.0], + "2.09": [5888.0, 2816.0], + "2.4": [6144.0, 2560.0], + "2.5": [6400.0, 2560.0], + "2.89": [6656.0, 2304.0], + "3.0": [6912.0, 2304.0], + "3.11": [7168.0, 2304.0], + "3.62": [7424.0, 2048.0], + "3.75": [7680.0, 2048.0], + "3.88": [7936.0, 2048.0], + "4.0": [8192.0, 2048.0], +} + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import SanaControlNetPipeline + >>> from diffusers.utils import load_image + + >>> pipe = SanaControlNetPipeline.from_pretrained( + ... "ishan24/Sana_600M_1024px_ControlNetPlus_diffusers", + ... variant="fp16", + ... torch_dtype={"default": torch.bfloat16, "controlnet": torch.float16, "transformer": torch.float16}, + ... device_map="balanced", + ... ) + >>> cond_image = load_image( + ... "https://huggingface.co/ishan24/Sana_600M_1024px_ControlNet_diffusers/resolve/main/hed_example.png" + ... ) + >>> prompt = 'a cat with a neon sign that says "Sana"' + >>> image = pipe( + ... prompt, + ... control_image=cond_image, + ... ).images[0] + >>> image.save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin): + r""" + Pipeline for text-to-image generation using [Sana](https://huggingface.co/papers/2410.10629). + """ + + # fmt: off + bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}") + # fmt: on + + model_cpu_offload_seq = "text_encoder->controlnet->transformer->vae" + _callback_tensor_inputs = ["latents", "control_image", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + text_encoder: Gemma2PreTrainedModel, + vae: AutoencoderDC, + transformer: SanaTransformer2DModel, + controlnet: SanaControlNetModel, + scheduler: DPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + transformer=transformer, + controlnet=controlnet, + scheduler=scheduler, + ) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.encoder_block_out_channels) - 1) + if hasattr(self, "vae") and self.vae is not None + else 32 + ) + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds + def _get_gemma_prompt_embeds( + self, + prompt: str | list[str], + device: torch.device, + dtype: torch.dtype, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + + # prepare complex human instruction + if not complex_human_instruction: + max_length_all = max_sequence_length + else: + chi_prompt = "\n".join(complex_human_instruction) + prompt = [chi_prompt + p for p in prompt] + num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) + max_length_all = num_chi_prompt_tokens + max_sequence_length - 2 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length_all, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + lora_scale: float | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the "" string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + + if device is None: + device = self._execution_device + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + # See Section 3.1. of the paper. + max_length = max_sequence_length + select_index = [0] + list(range(-max_length + 1, 0)) + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + ) + + prompt_embeds = prompt_embeds[:, select_index] + prompt_attention_mask = prompt_attention_mask[:, select_index] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=False, + ) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + if self.text_encoder is not None: + if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 4.5, + control_image: PipelineImageInput = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + num_images_per_prompt: int | None = 1, + height: int = 1024, + width: int = 1024, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + clean_caption: bool = False, + use_resolution_binning: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 300, + complex_human_instruction: list[str] = [ + "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:", + "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.", + "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.", + "Here are examples of how to transform or refine prompts:", + "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.", + "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.", + "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:", + "User Prompt: ", + ], + ) -> SanaPipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 20): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + attention_kwargs: + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to `300`): + Maximum sequence length to use with the `prompt`. + complex_human_instruction (`list[str]`, *optional*): + Instructions for complex human attention: + https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55. + + Examples: + + Returns: + [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + if use_resolution_binning: + if self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_4096_BIN + elif self.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_2048_BIN + elif self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.transformer.config.sample_size == 16: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + height, + width, + callback_on_step_end_tensor_inputs, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + lora_scale=lora_scale, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare control image + if isinstance(self.controlnet, SanaControlNetModel): + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + height, width = control_image.shape[-2:] + + control_image = self.vae.encode(control_image).latent + control_image = control_image * self.vae.config.scaling_factor + else: + raise ValueError("`controlnet` must be of type `SanaControlNetModel`.") + + # 5. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + + # 6. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + controlnet_dtype = self.controlnet.dtype + transformer_dtype = self.transformer.dtype + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # controlnet(s) inference + controlnet_block_samples = self.controlnet( + latent_model_input.to(dtype=controlnet_dtype), + encoder_hidden_states=prompt_embeds.to(dtype=controlnet_dtype), + encoder_attention_mask=prompt_attention_mask, + timestep=timestep, + return_dict=False, + attention_kwargs=self.attention_kwargs, + controlnet_cond=control_image, + conditioning_scale=controlnet_conditioning_scale, + )[0] + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input.to(dtype=transformer_dtype), + encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype), + encoder_attention_mask=prompt_attention_mask, + timestep=timestep, + return_dict=False, + attention_kwargs=self.attention_kwargs, + controlnet_block_samples=tuple(t.to(dtype=transformer_dtype) for t in controlnet_block_samples), + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + else: + latents = latents.to(self.vae.dtype) + torch_accelerator_module = getattr(torch, get_device(), torch.cuda) + oom_error = ( + torch.OutOfMemoryError + if is_torch_version(">=", "2.5.0") + else torch_accelerator_module.OutOfMemoryError + ) + try: + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + except oom_error as e: + warnings.warn( + f"{e}. \n" + f"Try to use VAE tiling for large images. For example: \n" + f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)" + ) + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return SanaPipelineOutput(images=image) diff --git a/diffusers/pipelines/sana/pipeline_sana_sprint.py b/diffusers/pipelines/sana/pipeline_sana_sprint.py new file mode 100644 index 0000000000000000000000000000000000000000..c85f05275f51165e154306614b34cc601e9988a5 --- /dev/null +++ b/diffusers/pipelines/sana/pipeline_sana_sprint.py @@ -0,0 +1,922 @@ +# Copyright 2025 SANA-Sprint Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +import warnings +from typing import Any, Callable + +import torch +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PixArtImageProcessor +from ...loaders import SanaLoraLoaderMixin +from ...models import AutoencoderDC, SanaTransformer2DModel +from ...schedulers import DPMSolverMultistepScheduler +from ...utils import ( + BACKENDS_MAPPING, + USE_PEFT_BACKEND, + deprecate, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import get_device, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN +from .pipeline_output import SanaPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import SanaSprintPipeline + + >>> pipe = SanaSprintPipeline.from_pretrained( + ... "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> image = pipe(prompt="a tiny astronaut hatching from an egg on the moon")[0] + >>> image[0].save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin): + r""" + Pipeline for text-to-image generation using [SANA-Sprint](https://huggingface.co/papers/2503.09641). + """ + + # fmt: off + bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}") + # fmt: on + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + text_encoder: Gemma2PreTrainedModel, + vae: AutoencoderDC, + transformer: SanaTransformer2DModel, + scheduler: DPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.encoder_block_out_channels) - 1) + if hasattr(self, "vae") and self.vae is not None + else 32 + ) + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds + def _get_gemma_prompt_embeds( + self, + prompt: str | list[str], + device: torch.device, + dtype: torch.dtype, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + + # prepare complex human instruction + if not complex_human_instruction: + max_length_all = max_sequence_length + else: + chi_prompt = "\n".join(complex_human_instruction) + prompt = [chi_prompt + p for p in prompt] + num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) + max_length_all = num_chi_prompt_tokens + max_sequence_length - 2 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length_all, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device) + + return prompt_embeds, prompt_attention_mask + + def encode_prompt( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + lora_scale: float | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + + if device is None: + device = self._execution_device + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + # See Section 3.1. of the paper. + max_length = max_sequence_length + select_index = [0] + list(range(-max_length + 1, 0)) + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + ) + + prompt_embeds = prompt_embeds[:, select_index] + prompt_attention_mask = prompt_attention_mask[:, select_index] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if self.text_encoder is not None: + if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + num_inference_steps, + timesteps, + max_timesteps, + intermediate_timesteps, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if timesteps is not None and len(timesteps) != num_inference_steps + 1: + raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.") + + if timesteps is not None and max_timesteps is not None: + raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.") + + if timesteps is None and max_timesteps is None: + raise ValueError("Should provide either `timesteps` or `max_timesteps`.") + + if intermediate_timesteps is not None and num_inference_steps != 2: + raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.") + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + num_inference_steps: int = 2, + timesteps: list[int] = None, + max_timesteps: float = 1.57080, + intermediate_timesteps: float = 1.3, + guidance_scale: float = 4.5, + num_images_per_prompt: int | None = 1, + height: int = 1024, + width: int = 1024, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + clean_caption: bool = False, + use_resolution_binning: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 300, + complex_human_instruction: list[str] = [ + "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:", + "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.", + "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.", + "Here are examples of how to transform or refine prompts:", + "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.", + "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.", + "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:", + "User Prompt: ", + ], + ) -> SanaPipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + num_inference_steps (`int`, *optional*, defaults to 20): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + max_timesteps (`float`, *optional*, defaults to 1.57080): + The maximum timestep value used in the SCM scheduler. + intermediate_timesteps (`float`, *optional*, defaults to 1.3): + The intermediate timestep value used in SCM scheduler (only used when num_inference_steps=2). + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 4.5): + Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + attention_kwargs: + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to `300`): + Maximum sequence length to use with the `prompt`. + complex_human_instruction (`list[str]`, *optional*): + Instructions for complex human attention: + https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55. + + Examples: + + Returns: + [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + if use_resolution_binning: + if self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt=prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + timesteps=timesteps, + max_timesteps=max_timesteps, + intermediate_timesteps=intermediate_timesteps, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + ) = self.encode_prompt( + prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + lora_scale=lora_scale, + ) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + timestep_device, + timesteps, + sigmas=None, + max_timesteps=max_timesteps, + intermediate_timesteps=intermediate_timesteps, + ) + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(0) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + latents = latents * self.scheduler.config.sigma_data + + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]).to(prompt_embeds.dtype) + guidance = guidance * self.transformer.config.guidance_embeds_scale + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + timesteps = timesteps[:-1] + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + transformer_dtype = self.transformer.dtype + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + latents_model_input = latents / self.scheduler.config.sigma_data + + scm_timestep = torch.sin(timestep) / (torch.cos(timestep) + torch.sin(timestep)) + + scm_timestep_expanded = scm_timestep.view(-1, 1, 1, 1) + latent_model_input = latents_model_input * torch.sqrt( + scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2 + ) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input.to(dtype=transformer_dtype), + encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype), + encoder_attention_mask=prompt_attention_mask, + guidance=guidance, + timestep=scm_timestep, + return_dict=False, + attention_kwargs=self.attention_kwargs, + )[0] + + noise_pred = ( + (1 - 2 * scm_timestep_expanded) * latent_model_input + + (1 - 2 * scm_timestep_expanded + 2 * scm_timestep_expanded**2) * noise_pred + ) / torch.sqrt(scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2) + noise_pred = noise_pred.float() * self.scheduler.config.sigma_data + + # compute previous image: x_t -> x_t-1 + latents, denoised = self.scheduler.step( + noise_pred, timestep, latents, **extra_step_kwargs, return_dict=False + ) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + latents = denoised / self.scheduler.config.sigma_data + if output_type == "latent": + image = latents + else: + latents = latents.to(self.vae.dtype) + torch_accelerator_module = getattr(torch, get_device(), torch.cuda) + oom_error = ( + torch.OutOfMemoryError + if is_torch_version(">=", "2.5.0") + else torch_accelerator_module.OutOfMemoryError + ) + try: + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + except oom_error as e: + warnings.warn( + f"{e}. \n" + f"Try to use VAE tiling for large images. For example: \n" + f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)" + ) + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return SanaPipelineOutput(images=image) diff --git a/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py b/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..a17c494e88eb1445a62970c8907c3ae9a0de1d07 --- /dev/null +++ b/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py @@ -0,0 +1,1006 @@ +# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +import warnings +from typing import Any, Callable + +import torch +import torch.nn.functional as F +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, PixArtImageProcessor +from ...loaders import SanaLoraLoaderMixin +from ...models import AutoencoderDC, SanaTransformer2DModel +from ...schedulers import DPMSolverMultistepScheduler +from ...utils import ( + BACKENDS_MAPPING, + USE_PEFT_BACKEND, + deprecate, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import get_device, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN +from .pipeline_output import SanaPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import SanaSprintImg2ImgPipeline + >>> from diffusers.utils.loading_utils import load_image + + >>> pipe = SanaSprintImg2ImgPipeline.from_pretrained( + ... "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png" + ... ) + + + >>> image = pipe(prompt="a cute pink bear", image=image, strength=0.5, height=832, width=480).images[0] + >>> image[0].save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class SanaSprintImg2ImgPipeline(DiffusionPipeline, SanaLoraLoaderMixin): + r""" + Pipeline for text-to-image generation using [SANA-Sprint](https://huggingface.co/papers/2503.09641). + """ + + # fmt: off + bad_punct_regex = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}") + # fmt: on + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + text_encoder: Gemma2PreTrainedModel, + vae: AutoencoderDC, + transformer: SanaTransformer2DModel, + scheduler: DPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.encoder_block_out_channels) - 1) + if hasattr(self, "vae") and self.vae is not None + else 32 + ) + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds + def _get_gemma_prompt_embeds( + self, + prompt: str | list[str], + device: torch.device, + dtype: torch.dtype, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + + # prepare complex human instruction + if not complex_human_instruction: + max_length_all = max_sequence_length + else: + chi_prompt = "\n".join(complex_human_instruction) + prompt = [chi_prompt + p for p in prompt] + num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) + max_length_all = num_chi_prompt_tokens + max_sequence_length - 2 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length_all, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.sana.pipeline_sana_sprint.SanaSprintPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + lora_scale: float | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + + if device is None: + device = self._execution_device + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + # See Section 3.1. of the paper. + max_length = max_sequence_length + select_index = [0] + list(range(-max_length + 1, 0)) + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + ) + + prompt_embeds = prompt_embeds[:, select_index] + prompt_attention_mask = prompt_attention_mask[:, select_index] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if self.text_encoder is not None: + if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + strength, + height, + width, + num_inference_steps, + timesteps, + max_timesteps, + intermediate_timesteps, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + prompt_attention_mask=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if timesteps is not None and len(timesteps) != num_inference_steps + 1: + raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.") + + if timesteps is not None and max_timesteps is not None: + raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.") + + if timesteps is None and max_timesteps is None: + raise ValueError("Should provide either `timesteps` or `max_timesteps`.") + + if intermediate_timesteps is not None and num_inference_steps != 2: + raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.") + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def prepare_image( + self, + image: PipelineImageInput, + width: int, + height: int, + device: torch.device, + dtype: torch.dtype, + ): + if isinstance(image, torch.Tensor): + if image.ndim == 3: + image = image.unsqueeze(0) + # Resize if current dimensions do not match target dimensions. + if image.shape[2] != height or image.shape[3] != width: + image = F.interpolate(image, size=(height, width), mode="bilinear", align_corners=False) + + image = self.image_processor.preprocess(image, height=height, width=width) + + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image = image.to(device=device, dtype=dtype) + + return image + + def prepare_latents( + self, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + + if image.shape[1] != num_channels_latents: + image = self.vae.encode(image).latent + image_latents = image * self.vae.config.scaling_factor * self.scheduler.config.sigma_data + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # adapt from https://github.com/huggingface/diffusers/blob/c36f8487df35895421c15f351c7d360bd680[…]/examples/research_projects/sana/train_sana_sprint_diffusers.py + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) * self.scheduler.config.sigma_data + latents = torch.cos(timestep) * image_latents + torch.sin(timestep) * noise + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + num_inference_steps: int = 2, + timesteps: list[int] = None, + max_timesteps: float = 1.57080, + intermediate_timesteps: float = 1.3, + guidance_scale: float = 4.5, + image: PipelineImageInput = None, + strength: float = 0.6, + num_images_per_prompt: int | None = 1, + height: int = 1024, + width: int = 1024, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + clean_caption: bool = False, + use_resolution_binning: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 300, + complex_human_instruction: list[str] = [ + "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:", + "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.", + "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.", + "Here are examples of how to transform or refine prompts:", + "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.", + "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.", + "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:", + "User Prompt: ", + ], + ) -> SanaPipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + num_inference_steps (`int`, *optional*, defaults to 20): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + max_timesteps (`float`, *optional*, defaults to 1.57080): + The maximum timestep value used in the SCM scheduler. + intermediate_timesteps (`float`, *optional*, defaults to 1.3): + The intermediate timestep value used in SCM scheduler (only used when num_inference_steps=2). + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + attention_kwargs: + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to `300`): + Maximum sequence length to use with the `prompt`. + complex_human_instruction (`list[str]`, *optional*): + Instructions for complex human attention: + https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55. + + Examples: + + Returns: + [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + if use_resolution_binning: + if self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt=prompt, + strength=strength, + height=height, + width=width, + num_inference_steps=num_inference_steps, + timesteps=timesteps, + max_timesteps=max_timesteps, + intermediate_timesteps=intermediate_timesteps, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None + + # 2. Preprocess image + init_image = self.prepare_image(image, width, height, device, self.vae.dtype) + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + ) = self.encode_prompt( + prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + lora_scale=lora_scale, + ) + + # 5. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=None, + max_timesteps=max_timesteps, + intermediate_timesteps=intermediate_timesteps, + ) + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(0) + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1] + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # I think this is redundant given the scaling in prepare_latents + # latents = latents * self.scheduler.config.sigma_data + + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]).to(prompt_embeds.dtype) + guidance = guidance * self.transformer.config.guidance_embeds_scale + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + timesteps = timesteps[:-1] + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + transformer_dtype = self.transformer.dtype + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + latents_model_input = latents / self.scheduler.config.sigma_data + + scm_timestep = torch.sin(timestep) / (torch.cos(timestep) + torch.sin(timestep)) + + scm_timestep_expanded = scm_timestep.view(-1, 1, 1, 1) + latent_model_input = latents_model_input * torch.sqrt( + scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2 + ) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input.to(dtype=transformer_dtype), + encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype), + encoder_attention_mask=prompt_attention_mask, + guidance=guidance, + timestep=scm_timestep, + return_dict=False, + attention_kwargs=self.attention_kwargs, + )[0] + + noise_pred = ( + (1 - 2 * scm_timestep_expanded) * latent_model_input + + (1 - 2 * scm_timestep_expanded + 2 * scm_timestep_expanded**2) * noise_pred + ) / torch.sqrt(scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2) + noise_pred = noise_pred.float() * self.scheduler.config.sigma_data + + # compute previous image: x_t -> x_t-1 + latents, denoised = self.scheduler.step( + noise_pred, timestep, latents, **extra_step_kwargs, return_dict=False + ) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + latents = denoised / self.scheduler.config.sigma_data + if output_type == "latent": + image = latents + else: + latents = latents.to(self.vae.dtype) + torch_accelerator_module = getattr(torch, get_device(), torch.cuda) + oom_error = ( + torch.OutOfMemoryError + if is_torch_version(">=", "2.5.0") + else torch_accelerator_module.OutOfMemoryError + ) + try: + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + except oom_error as e: + warnings.warn( + f"{e}. \n" + f"Try to use VAE tiling for large images. For example: \n" + f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)" + ) + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return SanaPipelineOutput(images=image) diff --git a/diffusers/pipelines/sana_video/__init__.py b/diffusers/pipelines/sana_video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..73e224bf749d358e46a2cd4e762c7875b8eeb761 --- /dev/null +++ b/diffusers/pipelines/sana_video/__init__.py @@ -0,0 +1,49 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_sana_video"] = ["SanaVideoPipeline"] + _import_structure["pipeline_sana_video_i2v"] = ["SanaImageToVideoPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_sana_video import SanaVideoPipeline + from .pipeline_sana_video_i2v import SanaImageToVideoPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/sana_video/pipeline_output.py b/diffusers/pipelines/sana_video/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..a625ad3e1d3449b48308bb7ce4a51fcfcc36b23b --- /dev/null +++ b/diffusers/pipelines/sana_video/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from ...utils import BaseOutput + + +@dataclass +class SanaVideoPipelineOutput(BaseOutput): + r""" + Output class for Sana-Video pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/diffusers/pipelines/sana_video/pipeline_sana_video.py b/diffusers/pipelines/sana_video/pipeline_sana_video.py new file mode 100644 index 0000000000000000000000000000000000000000..8b44dfc1143c4dba1e44b2c41d5ed1835c6cbf17 --- /dev/null +++ b/diffusers/pipelines/sana_video/pipeline_sana_video.py @@ -0,0 +1,1017 @@ +# Copyright 2025 SANA-Video Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +import warnings +from typing import Any, Callable + +import torch +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import SanaLoraLoaderMixin +from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel +from ...schedulers import DPMSolverMultistepScheduler +from ...utils import ( + BACKENDS_MAPPING, + USE_PEFT_BACKEND, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import get_device, is_torch_version, randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import SanaVideoPipelineOutput + + +ASPECT_RATIO_480_BIN = { + "0.5": [448.0, 896.0], + "0.57": [480.0, 832.0], + "0.68": [528.0, 768.0], + "0.78": [560.0, 720.0], + "1.0": [624.0, 624.0], + "1.13": [672.0, 592.0], + "1.29": [720.0, 560.0], + "1.46": [768.0, 528.0], + "1.67": [816.0, 496.0], + "1.75": [832.0, 480.0], + "2.0": [896.0, 448.0], +} + + +ASPECT_RATIO_720_BIN = { + "0.5": [672.0, 1344.0], + "0.57": [704.0, 1280.0], + "0.68": [800.0, 1152.0], + "0.78": [832.0, 1088.0], + "1.0": [960.0, 960.0], + "1.13": [1024.0, 896.0], + "1.29": [1088.0, 832.0], + "1.46": [1152.0, 800.0], + "1.67": [1248.0, 736.0], + "1.75": [1280.0, 704.0], + "2.0": [1344.0, 672.0], +} + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import SanaVideoPipeline + >>> from diffusers.utils import export_to_video + + >>> pipe = SanaVideoPipeline.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers") + >>> pipe.transformer.to(torch.bfloat16) + >>> pipe.text_encoder.to(torch.bfloat16) + >>> pipe.vae.to(torch.float32) + >>> pipe.to("cuda") + >>> motion_score = 30 + + >>> prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest, golden light glimmers on his hair as sunlight filters through the leaves. He wears a light shirt, wind gently blowing his hair and collar, light dances across his face with his movements. The background is blurred, with dappled light and soft tree shadows in the distance. The camera focuses on his lifted gaze, clear and emotional." + >>> negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience." + >>> motion_prompt = f" motion score: {motion_score}." + >>> prompt = prompt + motion_prompt + + >>> output = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=480, + ... width=832, + ... frames=81, + ... guidance_scale=6, + ... num_inference_steps=50, + ... generator=torch.Generator(device="cuda").manual_seed(42), + ... ).frames[0] + + >>> export_to_video(output, "sana-video-output.mp4", fps=16) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using [Sana](https://huggingface.co/papers/2509.24695). This model inherits + from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all + pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`GemmaTokenizer`] or [`GemmaTokenizerFast`]): + The tokenizer used to tokenize the prompt. + text_encoder ([`Gemma2PreTrainedModel`]): + Text encoder model to encode the input prompts. + vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + transformer ([`SanaVideoTransformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`DPMSolverMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + # fmt: off + bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}") + # fmt: on + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + text_encoder: Gemma2PreTrainedModel, + vae: AutoencoderDC | AutoencoderKLWan, + transformer: SanaVideoTransformer3DModel, + scheduler: DPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + + self.vae_scale_factor = self.vae_scale_factor_spatial + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds + def _get_gemma_prompt_embeds( + self, + prompt: str | list[str], + device: torch.device, + dtype: torch.dtype, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + + # prepare complex human instruction + if not complex_human_instruction: + max_length_all = max_sequence_length + else: + chi_prompt = "\n".join(complex_human_instruction) + prompt = [chi_prompt + p for p in prompt] + num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) + max_length_all = num_chi_prompt_tokens + max_sequence_length - 2 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length_all, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device) + + return prompt_embeds, prompt_attention_mask + + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_videos_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + lora_scale: float | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_videos_per_prompt (`int`, *optional*, defaults to 1): + number of videos that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the "" string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + + if device is None: + device = self._execution_device + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + # See Section 3.1. of the paper. + max_length = max_sequence_length + select_index = [0] + list(range(-max_length + 1, 0)) + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + ) + + prompt_embeds = prompt_embeds[:, select_index] + prompt_attention_mask = prompt_attention_mask[:, select_index] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=False, + ) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_videos_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + if self.text_encoder is not None: + if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str = "", + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 6.0, + num_videos_per_prompt: int | None = 1, + height: int = 480, + width: int = 832, + frames: int = 81, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + clean_caption: bool = False, + use_resolution_binning: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 300, + complex_human_instruction: list[str] = [ + "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for video generation. Evaluate the level of detail in the user prompt:", + "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, motion, and temporal relationships to create vivid and dynamic scenes.", + "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.", + "Here are examples of how to transform or refine prompts:", + "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat slowly settling into a curled position, peacefully falling asleep on a warm sunny windowsill, with gentle sunlight filtering through surrounding pots of blooming red flowers.", + "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps gradually lighting up, a diverse crowd of people in colorful clothing walking past, and a double-decker bus smoothly passing by towering glass skyscrapers.", + "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:", + "User Prompt: ", + ], + ) -> SanaVideoPipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked to + the text `prompt`, usually at the expense of lower video quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + height (`int`, *optional*, defaults to 480): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to 832): + The width in pixels of the generated video. + frames (`int`, *optional*, defaults to 81): + The number of frames in the generated video. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between mp4 or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`SanaVideoPipelineOutput`] instead of a plain tuple. + attention_kwargs: + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_480_BIN` or `ASPECT_RATIO_720_BIN`. After the produced latents are decoded into videos, + they are resized back to the requested resolution. Useful for generating non-square videos. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to `300`): + Maximum sequence length to use with the `prompt`. + complex_human_instruction (`list[str]`, *optional*): + Instructions for complex human attention: + https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55. + + Examples: + + Returns: + [`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated videos + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + if use_resolution_binning: + if self.transformer.config.sample_size == 30: + aspect_ratio_bin = ASPECT_RATIO_480_BIN + elif self.transformer.config.sample_size == 22: + aspect_ratio_bin = ASPECT_RATIO_720_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.video_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + height, + width, + callback_on_step_end_tensor_inputs, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + lora_scale=lora_scale, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + height, + width, + frames, + torch.float32, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + transformer_dtype = self.transformer.dtype + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input.to(dtype=transformer_dtype), + encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype), + encoder_attention_mask=prompt_attention_mask, + timestep=timestep, + return_dict=False, + attention_kwargs=self.attention_kwargs, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + video = latents + else: + latents = latents.to(self.vae.dtype) + torch_accelerator_module = getattr(torch, get_device(), torch.cuda) + oom_error = ( + torch.OutOfMemoryError + if is_torch_version(">=", "2.5.0") + else torch_accelerator_module.OutOfMemoryError + ) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + try: + video = self.vae.decode(latents, return_dict=False)[0] + except oom_error as e: + warnings.warn( + f"{e}. \n" + f"Try to use VAE tiling for large images. For example: \n" + f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)" + ) + + if use_resolution_binning: + video = self.video_processor.resize_and_crop_tensor(video, orig_width, orig_height) + + video = self.video_processor.postprocess_video(video, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return SanaVideoPipelineOutput(frames=video) diff --git a/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py b/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py new file mode 100644 index 0000000000000000000000000000000000000000..b90d7c6f5a60ae7bdeadd067a9c462fe1f7e3d50 --- /dev/null +++ b/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py @@ -0,0 +1,1066 @@ +# Copyright 2025 SANA-Video Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +import warnings +from typing import Any, Callable + +import PIL +import torch +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import SanaLoraLoaderMixin +from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + BACKENDS_MAPPING, + USE_PEFT_BACKEND, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import get_device, is_torch_version, randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import SanaVideoPipelineOutput +from .pipeline_sana_video import ASPECT_RATIO_480_BIN, ASPECT_RATIO_720_BIN + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import SanaImageToVideoPipeline + >>> from diffusers.utils import export_to_video, load_image + + >>> pipe = SanaImageToVideoPipeline.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers") + >>> pipe.transformer.to(torch.bfloat16) + >>> pipe.text_encoder.to(torch.bfloat16) + >>> pipe.vae.to(torch.float32) + >>> pipe.to("cuda") + >>> motion_score = 30 + + >>> prompt = "A woman stands against a stunning sunset backdrop, her long, wavy brown hair gently blowing in the breeze. She wears a sleeveless, light-colored blouse with a deep V-neckline, which accentuates her graceful posture. The warm hues of the setting sun cast a golden glow across her face and hair, creating a serene and ethereal atmosphere. The background features a blurred landscape with soft, rolling hills and scattered clouds, adding depth to the scene. The camera remains steady, capturing the tranquil moment from a medium close-up angle." + >>> negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience." + >>> motion_prompt = f" motion score: {motion_score}." + >>> prompt = prompt + motion_prompt + >>> image = load_image("https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/samples/i2v-1.png") + + >>> output = pipe( + ... image=image, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=480, + ... width=832, + ... frames=81, + ... guidance_scale=6, + ... num_inference_steps=50, + ... generator=torch.Generator(device="cuda").manual_seed(42), + ... ).frames[0] + + >>> export_to_video(output, "sana-ti2v-output.mp4", fps=16) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin): + r""" + Pipeline for image/text-to-video generation using [Sana](https://huggingface.co/papers/2509.24695). This model + inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all + pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`GemmaTokenizer`] or [`GemmaTokenizerFast`]): + The tokenizer used to tokenize the prompt. + text_encoder ([`Gemma2PreTrainedModel`]): + Text encoder model to encode the input prompts. + vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + transformer ([`SanaVideoTransformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + # fmt: off + bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}") + # fmt: on + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + text_encoder: Gemma2PreTrainedModel, + vae: AutoencoderDC | AutoencoderKLWan, + transformer: SanaVideoTransformer3DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + + self.vae_scale_factor = self.vae_scale_factor_spatial + + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size[1] if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size[0] if getattr(self, "transformer") is not None else 1 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds + def _get_gemma_prompt_embeds( + self, + prompt: str | list[str], + device: torch.device, + dtype: torch.dtype, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + + # prepare complex human instruction + if not complex_human_instruction: + max_length_all = max_sequence_length + else: + chi_prompt = "\n".join(complex_human_instruction) + prompt = [chi_prompt + p for p in prompt] + num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) + max_length_all = num_chi_prompt_tokens + max_sequence_length - 2 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length_all, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.sana_video.pipeline_sana_video.SanaVideoPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_videos_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + lora_scale: float | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_videos_per_prompt (`int`, *optional*, defaults to 1): + number of videos that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the "" string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + + if device is None: + device = self._execution_device + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + # See Section 3.1. of the paper. + max_length = max_sequence_length + select_index = [0] + list(range(-max_length + 1, 0)) + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + ) + + prompt_embeds = prompt_embeds[:, select_index] + prompt_attention_mask = prompt_attention_mask[:, select_index] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=False, + ) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_videos_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + if self.text_encoder is not None: + if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + height, + width, + callback_on_step_end_tensor_inputs=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def prepare_latents( + self, + image: PipelineImageInput, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + image = image.unsqueeze(2) # [B, C, 1, H, W] + image = image.to(device=device, dtype=self.vae.dtype) + + if isinstance(generator, list): + image_latents = [retrieve_latents(self.vae.encode(image), sample_mode="argmax") for _ in generator] + image_latents = torch.cat(image_latents) + else: + image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax") + image_latents = image_latents.repeat(batch_size, 1, 1, 1, 1) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, -1, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to( + image_latents.device, image_latents.dtype + ) + image_latents = (image_latents - latents_mean) * latents_std + + latents[:, :, 0:1] = image_latents.to(dtype) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: str | list[str] = None, + negative_prompt: str = "", + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 6.0, + num_videos_per_prompt: int | None = 1, + height: int = 480, + width: int = 832, + frames: int = 81, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + clean_caption: bool = False, + use_resolution_binning: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 300, + complex_human_instruction: list[str] = [ + "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for video generation. Evaluate the level of detail in the user prompt:", + "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, motion, and temporal relationships to create vivid and dynamic scenes.", + "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.", + "Here are examples of how to transform or refine prompts:", + "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat slowly settling into a curled position, peacefully falling asleep on a warm sunny windowsill, with gentle sunlight filtering through surrounding pots of blooming red flowers.", + "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps gradually lighting up, a diverse crowd of people in colorful clothing walking past, and a double-decker bus smoothly passing by towering glass skyscrapers.", + "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:", + "User Prompt: ", + ], + ) -> SanaVideoPipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the video generation on. The first frame of the generated video will be + conditioned on this image. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked to + the text `prompt`, usually at the expense of lower video quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + height (`int`, *optional*, defaults to 480): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to 832): + The width in pixels of the generated video. + frames (`int`, *optional*, defaults to 81): + The number of frames in the generated video. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between mp4 or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`SanaVideoPipelineOutput`] instead of a plain tuple. + attention_kwargs: + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_480_BIN` or `ASPECT_RATIO_720_BIN`. After the produced latents are decoded into videos, + they are resized back to the requested resolution. Useful for generating non-square videos. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to `300`): + Maximum sequence length to use with the `prompt`. + complex_human_instruction (`list[str]`, *optional*): + Instructions for complex human attention: + https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55. + + Examples: + + Returns: + [`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated videos + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + if use_resolution_binning: + if self.transformer.config.sample_size == 30: + aspect_ratio_bin = ASPECT_RATIO_480_BIN + elif self.transformer.config.sample_size == 22: + aspect_ratio_bin = ASPECT_RATIO_720_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.video_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + image, + height, + width, + callback_on_step_end_tensor_inputs, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + lora_scale=lora_scale, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + + latents = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + latent_channels, + height, + width, + frames, + torch.float32, + device, + generator, + latents, + ) + + conditioning_mask = latents.new_zeros( + batch_size, + 1, + latents.shape[2] // self.transformer_temporal_patch_size, + latents.shape[3] // self.transformer_spatial_patch_size, + latents.shape[4] // self.transformer_spatial_patch_size, + ) + conditioning_mask[:, :, 0] = 1.0 + if self.do_classifier_free_guidance: + conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + transformer_dtype = self.transformer.dtype + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(conditioning_mask.shape) + timestep = timestep * (1 - conditioning_mask) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input.to(dtype=transformer_dtype), + encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype), + encoder_attention_mask=prompt_attention_mask, + timestep=timestep, + return_dict=False, + attention_kwargs=self.attention_kwargs, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + timestep, _ = timestep.chunk(2) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + + noise_pred = noise_pred[:, :, 1:] + noise_latents = latents[:, :, 1:] + pred_latents = self.scheduler.step( + noise_pred, t, noise_latents, **extra_step_kwargs, return_dict=False + )[0] + + latents = torch.cat([latents[:, :, :1], pred_latents], dim=2) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + video = latents + else: + latents = latents.to(self.vae.dtype) + torch_accelerator_module = getattr(torch, get_device(), torch.cuda) + oom_error = ( + torch.OutOfMemoryError + if is_torch_version(">=", "2.5.0") + else torch_accelerator_module.OutOfMemoryError + ) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + try: + video = self.vae.decode(latents, return_dict=False)[0] + except oom_error as e: + warnings.warn( + f"{e}. \n" + f"Try to use VAE tiling for large images. For example: \n" + f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)" + ) + + if use_resolution_binning: + video = self.video_processor.resize_and_crop_tensor(video, orig_width, orig_height) + + video = self.video_processor.postprocess_video(video, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return SanaVideoPipelineOutput(frames=video) diff --git a/diffusers/pipelines/skyreels_v2/pipeline_output.py b/diffusers/pipelines/skyreels_v2/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..dac2316362ec20c1874b514637ef1681fb7dd7d3 --- /dev/null +++ b/diffusers/pipelines/skyreels_v2/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class SkyReelsV2PipelineOutput(BaseOutput): + r""" + Output class for SkyReelsV2 pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py b/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py new file mode 100644 index 0000000000000000000000000000000000000000..7c24b898e0bb0de16bdb866a67007ce55a5f8d30 --- /dev/null +++ b/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py @@ -0,0 +1,745 @@ +# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +from typing import Any, Callable + +import PIL +import regex as re +import torch +from transformers import AutoTokenizer, CLIPProcessor, CLIPVisionModelWithProjection, T5EncoderModel, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import SkyReelsV2LoraLoaderMixin +from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel +from ...schedulers import UniPCMultistepScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import SkyReelsV2PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """\ + Examples: + ```py + >>> import torch + >>> from diffusers import ( + ... SkyReelsV2ImageToVideoPipeline, + ... UniPCMultistepScheduler, + ... AutoencoderKLWan, + ... ) + >>> from diffusers.utils import export_to_video + >>> from PIL import Image + + >>> # Load the pipeline + >>> # Available models: + >>> # - Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers + >>> # - Skywork/SkyReels-V2-I2V-14B-540P-Diffusers + >>> # - Skywork/SkyReels-V2-I2V-14B-720P-Diffusers + >>> vae = AutoencoderKLWan.from_pretrained( + ... "Skywork/SkyReels-V2-I2V-14B-720P-Diffusers", + ... subfolder="vae", + ... torch_dtype=torch.float32, + ... ) + >>> pipe = SkyReelsV2ImageToVideoPipeline.from_pretrained( + ... "Skywork/SkyReels-V2-I2V-14B-720P-Diffusers", + ... vae=vae, + ... torch_dtype=torch.bfloat16, + ... ) + >>> flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift) + >>> pipe = pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + >>> image = Image.open("path/to/image.png") + + >>> output = pipe( + ... image=image, + ... prompt=prompt, + ... num_inference_steps=50, + ... height=544, + ... width=960, + ... guidance_scale=5.0, # 6.0 for T2V, 5.0 for I2V + ... num_frames=97, + ... ).frames[0] + >>> export_to_video(output, "video.mp4", fps=24, quality=8) + ``` +""" + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class SkyReelsV2ImageToVideoPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin): + r""" + Pipeline for Image-to-Video (i2v) generation using SkyReels-V2. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + image_encoder ([`CLIPVisionModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection), + specifically the + [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large) + variant. + transformer ([`SkyReelsV2Transformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: T5EncoderModel | UMT5EncoderModel, + image_encoder: CLIPVisionModelWithProjection, + image_processor: CLIPProcessor, + transformer: SkyReelsV2Transformer3DModel, + vae: AutoencoderKLWan, + scheduler: UniPCMultistepScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + image_encoder=image_encoder, + transformer=transformer, + scheduler=scheduler, + image_processor=image_processor, + ) + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.image_processor = image_processor + + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_image + def encode_image( + self, + image: PipelineImageInput, + device: torch.device | None = None, + ): + device = device or self._execution_device + image = self.image_processor(images=image, return_tensors="pt").to(device) + image_embeds = self.image_encoder(**image, output_hidden_states=True) + return image_embeds.hidden_states[-2] + + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if image is not None and image_embeds is not None: + raise ValueError( + f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to" + " only forward one of the two." + ) + if image is None and image_embeds is None: + raise ValueError( + "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined." + ) + if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + def prepare_latents( + self, + image: PipelineImageInput, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + last_image: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + image = image.unsqueeze(2) + if last_image is None: + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 + ) + else: + last_image = last_image.unsqueeze(2) + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image], + dim=2, + ) + video_condition = video_condition.to(device=device, dtype=self.vae.dtype) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + + if isinstance(generator, list): + latent_condition = [ + retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator + ] + latent_condition = torch.cat(latent_condition) + else: + latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") + latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + + latent_condition = latent_condition.to(dtype) + latent_condition = (latent_condition - latents_mean) * latents_std + + mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) + + if last_image is None: + mask_lat_size[:, :, list(range(1, num_frames))] = 0 + else: + mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0 + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(latent_condition.device) + + return latents, torch.concat([mask_lat_size, latent_condition], dim=1) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int = 544, + width: int = 960, + num_frames: int = 97, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + image_embeds: torch.Tensor | None = None, + last_image: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, defaults to `544`): + The height of the generated video. + width (`int`, defaults to `960`): + The width of the generated video. + num_frames (`int`, defaults to `97`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `negative_prompt` input argument. + image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided, + image embeddings are generated from the `image` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to `512`): + The maximum sequence length of the prompt. + + Examples: + + Returns: + [`~SkyReelsV2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds, + negative_prompt_embeds, + image_embeds, + callback_on_step_end_tensor_inputs, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + # Encode image embedding + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + if image_embeds is None: + if last_image is None: + image_embeds = self.encode_image(image, device) + else: + image_embeds = self.encode_image([image, last_image], device) + image_embeds = image_embeds.repeat(batch_size, 1, 1) + image_embeds = image_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.vae.config.z_dim + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + if last_image is not None: + last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( + device, dtype=torch.float32 + ) + latents, condition = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + last_image, + ) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return SkyReelsV2PipelineOutput(frames=video) diff --git a/diffusers/pipelines/stable_diffusion_diffedit/__init__.py b/diffusers/pipelines/stable_diffusion_diffedit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e2145edb96c6be124abf9e9a21b9a5e8a3f3d641 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion_diffedit/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py b/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py new file mode 100644 index 0000000000000000000000000000000000000000..43bc2eb955c7951fe896d07c1bb377f86d2449d7 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py @@ -0,0 +1,1548 @@ +# Copyright 2025 DiffEdit Authors and Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import DDIMInverseScheduler, KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + USE_PEFT_BACKEND, + BaseOutput, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class DiffEditInversionPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + latents (`torch.Tensor`) + inverted latents tensor + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `num_timesteps * batch_size` or numpy array of shape `(num_timesteps, + batch_size, height, width, num_channels)`. PIL images or numpy array present the denoised images of the + diffusion pipeline. + """ + + latents: torch.Tensor + images: list[PIL.Image.Image] | np.ndarray + + +EXAMPLE_DOC_STRING = """ + + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionDiffEditPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" + + >>> init_image = download_image(img_url).resize((768, 768)) + + >>> pipeline = StableDiffusionDiffEditPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 + ... ) + + >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.enable_model_cpu_offload() + + >>> mask_prompt = "A bowl of fruits" + >>> prompt = "A bowl of pears" + + >>> mask_image = pipeline.generate_mask(image=init_image, source_prompt=prompt, target_prompt=mask_prompt) + >>> image_latents = pipeline.invert(image=init_image, prompt=mask_prompt).latents + >>> image = pipeline(prompt=prompt, mask_image=mask_image, image_latents=image_latents).images[0] + ``` +""" + +EXAMPLE_INVERT_DOC_STRING = """ + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionDiffEditPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" + + >>> init_image = download_image(img_url).resize((768, 768)) + + >>> pipeline = StableDiffusionDiffEditPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 + ... ) + + >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.enable_model_cpu_offload() + + >>> prompt = "A bowl of fruits" + + >>> inverted_latents = pipeline.invert(image=init_image, prompt=prompt).latents + ``` +""" + + +def auto_corr_loss(hidden_states, generator=None): + reg_loss = 0.0 + for i in range(hidden_states.shape[0]): + for j in range(hidden_states.shape[1]): + noise = hidden_states[i : i + 1, j : j + 1, :, :] + while True: + roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item() + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2 + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2 + + if noise.shape[2] <= 8: + break + noise = torch.nn.functional.avg_pool2d(noise, kernel_size=2) + return reg_loss + + +def kl_divergence(hidden_states): + return hidden_states.var() + hidden_states.mean() ** 2 - 1 - torch.log(hidden_states.var() + 1e-7) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +def preprocess_mask(mask, batch_size: int = 1): + if not isinstance(mask, torch.Tensor): + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list): + if isinstance(mask[0], PIL.Image.Image): + mask = [np.array(m.convert("L")).astype(np.float32) / 255.0 for m in mask] + if isinstance(mask[0], np.ndarray): + mask = np.stack(mask, axis=0) if mask[0].ndim < 3 else np.concatenate(mask, axis=0) + mask = torch.from_numpy(mask) + elif isinstance(mask[0], torch.Tensor): + mask = torch.stack(mask, dim=0) if mask[0].ndim < 3 else torch.cat(mask, dim=0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + # Check mask shape + if batch_size > 1: + if mask.shape[0] == 1: + mask = torch.cat([mask] * batch_size) + elif mask.shape[0] > 1 and mask.shape[0] != batch_size: + raise ValueError( + f"`mask_image` with batch size {mask.shape[0]} cannot be broadcasted to batch size {batch_size} " + f"inferred by prompt inputs" + ) + + if mask.shape[1] != 1: + raise ValueError(f"`mask_image` must have 1 channel, but has {mask.shape[1]} channels") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("`mask_image` should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + return mask + + +class StableDiffusionDiffEditPipeline( + DeprecatedPipelineMixin, + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, +): + r""" + > [!WARNING] > This is an experimental feature! + + Pipeline for text-guided image inpainting using Stable Diffusion and DiffEdit. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading and saving methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + inverse_scheduler ([`DDIMInverseScheduler`]): + A scheduler to be used in combination with `unet` to fill in the unmasked part of the input latents. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + _last_supported_version = "0.33.1" + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "inverse_scheduler"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + inverse_scheduler: DDIMInverseScheduler, + requires_safety_checker: bool = True, + ): + super().__init__() + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration" + " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" + " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" + " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face" + " Hub, it would be very nice if you could open a Pull request for the" + " `scheduler/scheduler_config.json` file" + ) + deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["skip_prk_steps"] = True + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + inverse_scheduler=inverse_scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def check_inputs( + self, + prompt, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (strength is None) or (strength is not None and (strength < 0 or strength > 1)): + raise ValueError( + f"The value of `strength` should in [0.0, 1.0] but is, but is {strength} of type {type(strength)}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def check_source_inputs( + self, + source_prompt=None, + source_negative_prompt=None, + source_prompt_embeds=None, + source_negative_prompt_embeds=None, + ): + if source_prompt is not None and source_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `source_prompt`: {source_prompt} and `source_prompt_embeds`: {source_prompt_embeds}." + " Please make sure to only forward one of the two." + ) + elif source_prompt is None and source_prompt_embeds is None: + raise ValueError( + "Provide either `source_image` or `source_prompt_embeds`. Cannot leave all both of the arguments undefined." + ) + elif source_prompt is not None and ( + not isinstance(source_prompt, str) and not isinstance(source_prompt, list) + ): + raise ValueError(f"`source_prompt` has to be of type `str` or `list` but is {type(source_prompt)}") + + if source_negative_prompt is not None and source_negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `source_negative_prompt`: {source_negative_prompt} and `source_negative_prompt_embeds`:" + f" {source_negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if source_prompt_embeds is not None and source_negative_prompt_embeds is not None: + if source_prompt_embeds.shape != source_negative_prompt_embeds.shape: + raise ValueError( + "`source_prompt_embeds` and `source_negative_prompt_embeds` must have the same shape when passed" + f" directly, but got: `source_prompt_embeds` {source_prompt_embeds.shape} !=" + f" `source_negative_prompt_embeds` {source_negative_prompt_embeds.shape}." + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def get_inverse_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + + # safety for t_start overflow to prevent empty timsteps slice + if t_start == 0: + return self.inverse_scheduler.timesteps, num_inference_steps + timesteps = self.inverse_scheduler.timesteps[:-t_start] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_image_latents(self, image, batch_size, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 4: + latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0) + else: + latents = self.vae.encode(image).latent_dist.sample(generator) + + latents = self.vae.config.scaling_factor * latents + + if batch_size != latents.shape[0]: + if batch_size % latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_latents_per_image = batch_size // latents.shape[0] + latents = torch.cat([latents] * additional_latents_per_image, dim=0) + else: + raise ValueError( + f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts." + ) + else: + latents = torch.cat([latents], dim=0) + + return latents + + def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int): + pred_type = self.inverse_scheduler.config.prediction_type + alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep] + + beta_prod_t = 1 - alpha_prod_t + + if pred_type == "epsilon": + return model_output + elif pred_type == "sample": + return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5) + elif pred_type == "v_prediction": + return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`" + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def generate_mask( + self, + image: torch.Tensor | PIL.Image.Image = None, + target_prompt: str | list[str] | None = None, + target_negative_prompt: str | list[str] | None = None, + target_prompt_embeds: torch.Tensor | None = None, + target_negative_prompt_embeds: torch.Tensor | None = None, + source_prompt: str | list[str] | None = None, + source_negative_prompt: str | list[str] | None = None, + source_prompt_embeds: torch.Tensor | None = None, + source_negative_prompt_embeds: torch.Tensor | None = None, + num_maps_per_mask: int | None = 10, + mask_encode_strength: float | None = 0.5, + mask_thresholding_ratio: float | None = 3.0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + generator: torch.Generator | list[torch.Generator] | None = None, + output_type: str | None = "np", + cross_attention_kwargs: dict[str, Any] | None = None, + ): + r""" + Generate a latent mask given a mask prompt, a target prompt, and an image. + + Args: + image (`PIL.Image.Image`): + `Image` or tensor representing an image batch to be used for computing the mask. + target_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide semantic mask generation. If not defined, you need to pass + `prompt_embeds`. + target_negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + target_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + target_negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + source_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide semantic mask generation using DiffEdit. If not defined, you need to + pass `source_prompt_embeds` or `source_image` instead. + source_negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide semantic mask generation away from using DiffEdit. If not defined, you + need to pass `source_negative_prompt_embeds` or `source_image` instead. + source_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings to guide the semantic mask generation. Can be used to easily tweak text + inputs (prompt weighting). If not provided, text embeddings are generated from `source_prompt` input + argument. + source_negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings to negatively guide the semantic mask generation. Can be used to easily + tweak text inputs (prompt weighting). If not provided, text embeddings are generated from + `source_negative_prompt` input argument. + num_maps_per_mask (`int`, *optional*, defaults to 10): + The number of noise maps sampled to generate the semantic mask using DiffEdit. + mask_encode_strength (`float`, *optional*, defaults to 0.5): + The strength of the noise maps sampled to generate the semantic mask using DiffEdit. Must be between 0 + and 1. + mask_thresholding_ratio (`float`, *optional*, defaults to 3.0): + The maximum multiple of the mean absolute difference used to clamp the semantic guidance map before + mask binarization. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the + [`~models.attention_processor.AttnProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + Returns: + `list[PIL.Image.Image]` or `np.array`: + When returning a `list[PIL.Image.Image]`, the list consists of a batch of single-channel binary images + with dimensions `(height // self.vae_scale_factor, width // self.vae_scale_factor)`. If it's + `np.array`, the shape is `(batch_size, height // self.vae_scale_factor, width // + self.vae_scale_factor)`. + """ + + # 1. Check inputs (Provide dummy argument for callback_steps) + self.check_inputs( + target_prompt, + mask_encode_strength, + 1, + target_negative_prompt, + target_prompt_embeds, + target_negative_prompt_embeds, + ) + + self.check_source_inputs( + source_prompt, + source_negative_prompt, + source_prompt_embeds, + source_negative_prompt_embeds, + ) + + if (num_maps_per_mask is None) or ( + num_maps_per_mask is not None and (not isinstance(num_maps_per_mask, int) or num_maps_per_mask <= 0) + ): + raise ValueError( + f"`num_maps_per_mask` has to be a positive integer but is {num_maps_per_mask} of type" + f" {type(num_maps_per_mask)}." + ) + + if mask_thresholding_ratio is None or mask_thresholding_ratio <= 0: + raise ValueError( + f"`mask_thresholding_ratio` has to be positive but is {mask_thresholding_ratio} of type" + f" {type(mask_thresholding_ratio)}." + ) + + # 2. Define call parameters + if target_prompt is not None and isinstance(target_prompt, str): + batch_size = 1 + elif target_prompt is not None and isinstance(target_prompt, list): + batch_size = len(target_prompt) + else: + batch_size = target_prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompts + (cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None) + target_negative_prompt_embeds, target_prompt_embeds = self.encode_prompt( + target_prompt, + device, + num_maps_per_mask, + do_classifier_free_guidance, + target_negative_prompt, + prompt_embeds=target_prompt_embeds, + negative_prompt_embeds=target_negative_prompt_embeds, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + target_prompt_embeds = torch.cat([target_negative_prompt_embeds, target_prompt_embeds]) + + source_negative_prompt_embeds, source_prompt_embeds = self.encode_prompt( + source_prompt, + device, + num_maps_per_mask, + do_classifier_free_guidance, + source_negative_prompt, + prompt_embeds=source_prompt_embeds, + negative_prompt_embeds=source_negative_prompt_embeds, + ) + if do_classifier_free_guidance: + source_prompt_embeds = torch.cat([source_negative_prompt_embeds, source_prompt_embeds]) + + # 4. Preprocess image + image = self.image_processor.preprocess(image).repeat_interleave(num_maps_per_mask, dim=0) + + # 5. Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, _ = self.get_timesteps(num_inference_steps, mask_encode_strength, device) + encode_timestep = timesteps[0] + + # 6. Prepare image latents and add noise with specified strength + image_latents = self.prepare_image_latents( + image, batch_size * num_maps_per_mask, self.vae.dtype, device, generator + ) + noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=self.vae.dtype) + image_latents = self.scheduler.add_noise(image_latents, noise, encode_timestep) + + latent_model_input = torch.cat([image_latents] * (4 if do_classifier_free_guidance else 2)) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, encode_timestep) + + # 7. Predict the noise residual + prompt_embeds = torch.cat([source_prompt_embeds, target_prompt_embeds]) + noise_pred = self.unet( + latent_model_input, + encode_timestep, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + if do_classifier_free_guidance: + noise_pred_neg_src, noise_pred_source, noise_pred_uncond, noise_pred_target = noise_pred.chunk(4) + noise_pred_source = noise_pred_neg_src + guidance_scale * (noise_pred_source - noise_pred_neg_src) + noise_pred_target = noise_pred_uncond + guidance_scale * (noise_pred_target - noise_pred_uncond) + else: + noise_pred_source, noise_pred_target = noise_pred.chunk(2) + + # 8. Compute the mask from the absolute difference of predicted noise residuals + # TODO: Consider smoothing mask guidance map + mask_guidance_map = ( + torch.abs(noise_pred_target - noise_pred_source) + .reshape(batch_size, num_maps_per_mask, *noise_pred_target.shape[-3:]) + .mean([1, 2]) + ) + clamp_magnitude = mask_guidance_map.mean() * mask_thresholding_ratio + semantic_mask_image = mask_guidance_map.clamp(0, clamp_magnitude) / clamp_magnitude + semantic_mask_image = torch.where(semantic_mask_image <= 0.5, 0, 1) + mask_image = semantic_mask_image.cpu().numpy() + + # 9. Convert to Numpy array or PIL. + if output_type == "pil": + mask_image = self.image_processor.numpy_to_pil(mask_image) + + # Offload all models + self.maybe_free_model_hooks() + + return mask_image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_INVERT_DOC_STRING) + def invert( + self, + prompt: str | list[str] | None = None, + image: torch.Tensor | PIL.Image.Image = None, + num_inference_steps: int = 50, + inpaint_strength: float = 0.8, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + decode_latents: bool = False, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int | None = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + lambda_auto_corr: float = 20.0, + lambda_kl: float = 20.0, + num_reg_steps: int = 0, + num_auto_corr_rolls: int = 5, + ): + r""" + Generate inverted latents given a prompt and image. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`PIL.Image.Image`): + `Image` or tensor representing an image batch to produce the inverted latents guided by `prompt`. + inpaint_strength (`float`, *optional*, defaults to 0.8): + Indicates extent of the noising process to run latent inversion. Must be between 0 and 1. When + `inpaint_strength` is 1, the inversion process is run for the full number of iterations specified in + `num_inference_steps`. `image` is used as a reference for the inversion process, and adding more noise + increases `inpaint_strength`. If `inpaint_strength` is 0, no inpainting occurs. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + decode_latents (`bool`, *optional*, defaults to `False`): + Whether or not to decode the inverted latents into a generated image. Setting this argument to `True` + decodes all inverted latents for each timestep into a list of generated images. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.DiffEditInversionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the + [`~models.attention_processor.AttnProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + lambda_auto_corr (`float`, *optional*, defaults to 20.0): + Lambda parameter to control auto correction. + lambda_kl (`float`, *optional*, defaults to 20.0): + Lambda parameter to control Kullback-Leibler divergence output. + num_reg_steps (`int`, *optional*, defaults to 0): + Number of regularization loss steps. + num_auto_corr_rolls (`int`, *optional*, defaults to 5): + Number of auto correction roll steps. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] or + `tuple`: + If `return_dict` is `True`, + [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is the inverted latents tensors + ordered by increasing noise, and the second is the corresponding decoded images if `decode_latents` is + `True`, otherwise `None`. + """ + + # 1. Check inputs + self.check_inputs( + prompt, + inpaint_strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Preprocess image + image = self.image_processor.preprocess(image) + + # 4. Prepare latent variables + num_images_per_prompt = 1 + latents = self.prepare_image_latents( + image, batch_size * num_images_per_prompt, self.vae.dtype, device, generator + ) + + # 5. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 6. Prepare timesteps + self.inverse_scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_inverse_timesteps(num_inference_steps, inpaint_strength, device) + + # 7. Noising loop where we obtain the intermediate noised latent image for each timestep. + num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order + inverted_latents = [] + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # regularization of the noise prediction (not in original code or paper but borrowed from Pix2PixZero) + if num_reg_steps > 0: + with torch.enable_grad(): + for _ in range(num_reg_steps): + if lambda_auto_corr > 0: + for _ in range(num_auto_corr_rolls): + var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) + + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + + l_ac = auto_corr_loss(var_epsilon, generator=generator) + l_ac.backward() + + grad = var.grad.detach() / num_auto_corr_rolls + noise_pred = noise_pred - lambda_auto_corr * grad + + if lambda_kl > 0: + var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) + + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + + l_kld = kl_divergence(var_epsilon) + l_kld.backward() + + grad = var.grad.detach() + noise_pred = noise_pred - lambda_kl * grad + + noise_pred = noise_pred.detach() + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample + inverted_latents.append(latents.detach().clone()) + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + assert len(inverted_latents) == len(timesteps) + latents = torch.stack(list(reversed(inverted_latents)), 1) + + # 8. Post-processing + image = None + if decode_latents: + image = self.decode_latents(latents.flatten(0, 1)) + + # 9. Convert to PIL. + if decode_latents and output_type == "pil": + image = self.image_processor.numpy_to_pil(image) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (latents, image) + + return DiffEditInversionPipelineOutput(latents=latents, images=image) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + mask_image: torch.Tensor | PIL.Image.Image = None, + image_latents: torch.Tensor | PIL.Image.Image = None, + inpaint_strength: float | None = 0.8, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + mask_image (`PIL.Image.Image`): + `Image` or tensor representing an image batch to mask the generated image. White pixels in the mask are + repainted, while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, 1, H, W)`. + image_latents (`PIL.Image.Image` or `torch.Tensor`): + Partially noised image latents from the inversion process to be used as inputs for image generation. + inpaint_strength (`float`, *optional*, defaults to 0.8): + Indicates extent to inpaint the masked area. Must be between 0 and 1. When `inpaint_strength` is 1, the + denoising process is run on the masked area for the full number of iterations specified in + `num_inference_steps`. `image_latents` is used as a reference for the masked area, and adding more + noise to a region increases `inpaint_strength`. If `inpaint_strength` is 0, no inpainting occurs. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + # 1. Check inputs + self.check_inputs( + prompt, + inpaint_strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + if mask_image is None: + raise ValueError( + "`mask_image` input cannot be undefined. Use `generate_mask()` to compute `mask_image` from text prompts." + ) + if image_latents is None: + raise ValueError( + "`image_latents` input cannot be undefined. Use `invert()` to compute `image_latents` from input images." + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Preprocess mask + mask_image = preprocess_mask(mask_image, batch_size) + latent_height, latent_width = mask_image.shape[-2:] + mask_image = torch.cat([mask_image] * num_images_per_prompt) + mask_image = mask_image.to(device=device, dtype=prompt_embeds.dtype) + + # 5. Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, inpaint_strength, device) + + # 6. Preprocess image latents + if isinstance(image_latents, list) and any(isinstance(l, torch.Tensor) and l.ndim == 5 for l in image_latents): + image_latents = torch.cat(image_latents).detach() + elif isinstance(image_latents, torch.Tensor) and image_latents.ndim == 5: + image_latents = image_latents.detach() + else: + image_latents = self.image_processor.preprocess(image_latents).detach() + + latent_shape = (self.vae.config.latent_channels, latent_height, latent_width) + if image_latents.shape[-3:] != latent_shape: + raise ValueError( + f"Each latent image in `image_latents` must have shape {latent_shape}, " + f"but has shape {image_latents.shape[-3:]}" + ) + if image_latents.ndim == 4: + image_latents = image_latents.reshape(batch_size, len(timesteps), *latent_shape) + if image_latents.shape[:2] != (batch_size, len(timesteps)): + raise ValueError( + f"`image_latents` must have batch size {batch_size} with latent images from {len(timesteps)}" + f" timesteps, but has batch size {image_latents.shape[0]} with latent images from" + f" {image_latents.shape[1]} timesteps." + ) + image_latents = image_latents.transpose(0, 1).repeat_interleave(num_images_per_prompt, dim=1) + image_latents = image_latents.to(device=device, dtype=prompt_embeds.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + latents = image_latents[0].clone() + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # mask with inverted latents from appropriate timestep - use original image latent for last step + latents = latents * mask_image + image_latents[i] * (1 - mask_image) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/diffusers/pipelines/stable_diffusion_safe/__init__.py b/diffusers/pipelines/stable_diffusion_safe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b35015a9f72913526ee4bdece969f74d545aecfb --- /dev/null +++ b/diffusers/pipelines/stable_diffusion_safe/__init__.py @@ -0,0 +1,99 @@ +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING + +import numpy as np +import PIL +from PIL import Image + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + BaseOutput, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +@dataclass +class SafetyConfig(object): + WEAK = { + "sld_warmup_steps": 15, + "sld_guidance_scale": 20, + "sld_threshold": 0.0, + "sld_momentum_scale": 0.0, + "sld_mom_beta": 0.0, + } + MEDIUM = { + "sld_warmup_steps": 10, + "sld_guidance_scale": 1000, + "sld_threshold": 0.01, + "sld_momentum_scale": 0.3, + "sld_mom_beta": 0.4, + } + STRONG = { + "sld_warmup_steps": 7, + "sld_guidance_scale": 2000, + "sld_threshold": 0.025, + "sld_momentum_scale": 0.5, + "sld_mom_beta": 0.7, + } + MAX = { + "sld_warmup_steps": 0, + "sld_guidance_scale": 5000, + "sld_threshold": 1.0, + "sld_momentum_scale": 0.5, + "sld_mom_beta": 0.7, + } + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {} + +_additional_imports.update({"SafetyConfig": SafetyConfig}) + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure.update( + { + "pipeline_output": ["StableDiffusionSafePipelineOutput"], + "pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"], + "safety_checker": ["StableDiffusionSafetyChecker"], + } + ) + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_output import StableDiffusionSafePipelineOutput + from .pipeline_stable_diffusion_safe import StableDiffusionPipelineSafe + from .safety_checker import SafeStableDiffusionSafetyChecker + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/stable_diffusion_safe/pipeline_output.py b/diffusers/pipelines/stable_diffusion_safe/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..6b784bb0e10256451d1d3950edeebb58a929cf0c --- /dev/null +++ b/diffusers/pipelines/stable_diffusion_safe/pipeline_output.py @@ -0,0 +1,33 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import ( + BaseOutput, +) + + +@dataclass +class StableDiffusionSafePipelineOutput(BaseOutput): + """ + Output class for Safe Stable Diffusion pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`list[bool]`) + list of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, or `None` if safety checking could not be performed. + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images that were flagged by the safety checker any may contain "not-safe-for-work" + (nsfw) content, or `None` if no safety check was performed or no images were flagged. + applied_safety_concept (`str`) + The safety concept that was applied for safety guidance, or `None` if safety guidance was disabled + """ + + images: list[PIL.Image.Image] | np.ndarray + nsfw_content_detected: list[bool] | None + unsafe_images: list[PIL.Image.Image] | np.ndarray | None + applied_safety_concept: str | None diff --git a/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py new file mode 100644 index 0000000000000000000000000000000000000000..26bb5128ba9bf57623ff19ab4abf690f07e65b0b --- /dev/null +++ b/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -0,0 +1,784 @@ +import inspect +import warnings +from typing import Callable + +import numpy as np +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...configuration_utils import FrozenDict +from ...image_processor import PipelineImageInput +from ...loaders import IPAdapterMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import deprecate, is_torch_xla_available, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin +from . import StableDiffusionSafePipelineOutput +from .safety_checker import SafeStableDiffusionSafetyChecker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableDiffusionPipelineSafe(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin, IPAdapterMixin): + _last_supported_version = "0.33.1" + + r""" + Pipeline based on the [`StableDiffusionPipeline`] for text-to-image generation using Safe Latent Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: SafeStableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection | None = None, + requires_safety_checker: bool = True, + ): + super().__init__() + safety_concept: str | None = ( + "an image showing hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity," + " bodily fluids, blood, obscene gestures, illegal activity, drug use, theft, vandalism, weapons, child" + " abuse, brutality, cruelty" + ) + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self._safety_text_concept = safety_concept + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.register_to_config(requires_safety_checker=requires_safety_checker) + + @property + def safety_concept(self): + r""" + Getter method for the safety concept used with SLD + + Returns: + `str`: The text describing the safety concept + """ + return self._safety_text_concept + + @safety_concept.setter + def safety_concept(self, concept): + r""" + Setter method for the safety concept used with SLD + + Args: + concept (`str`): + The text of the new safety concept + """ + self._safety_text_concept = concept + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + enable_safety_guidance, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # Encode the safety concept text + if enable_safety_guidance: + safety_concept_input = self.tokenizer( + [self._safety_text_concept], + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + safety_embeddings = self.text_encoder(safety_concept_input.input_ids.to(self.device))[0] + + # duplicate safety embeddings for each generation per prompt, using mps friendly method + seq_len = safety_embeddings.shape[1] + safety_embeddings = safety_embeddings.repeat(batch_size, num_images_per_prompt, 1) + safety_embeddings = safety_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance + sld, we need to do three forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing three forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, safety_embeddings]) + + else: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def run_safety_checker(self, image, device, dtype, enable_safety_guidance): + if self.safety_checker is not None: + images = image.copy() + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + flagged_images = np.zeros((2, *image.shape[1:])) + if any(has_nsfw_concept): + logger.warning( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead." + f"{'You may look at this images in the `unsafe_images` variable of the output at your own discretion.' if enable_safety_guidance else 'Try again with a different prompt and/or seed.'}" + ) + for idx, has_nsfw_concept in enumerate(has_nsfw_concept): + if has_nsfw_concept: + flagged_images[idx] = images[idx] + image[idx] = np.zeros(image[idx].shape) # black image + else: + has_nsfw_concept = None + flagged_images = None + return image, has_nsfw_concept, flagged_images + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def perform_safety_guidance( + self, + enable_safety_guidance, + safety_momentum, + noise_guidance, + noise_pred_out, + i, + sld_guidance_scale, + sld_warmup_steps, + sld_threshold, + sld_momentum_scale, + sld_mom_beta, + ): + # Perform SLD guidance + if enable_safety_guidance: + if safety_momentum is None: + safety_momentum = torch.zeros_like(noise_guidance) + noise_pred_text, noise_pred_uncond = noise_pred_out[0], noise_pred_out[1] + noise_pred_safety_concept = noise_pred_out[2] + + # Equation 6 + scale = torch.clamp(torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0) + + # Equation 6 + safety_concept_scale = torch.where( + (noise_pred_text - noise_pred_safety_concept) >= sld_threshold, torch.zeros_like(scale), scale + ) + + # Equation 4 + noise_guidance_safety = torch.mul((noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale) + + # Equation 7 + noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum + + # Equation 8 + safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety + + if i >= sld_warmup_steps: # Warmup + # Equation 3 + noise_guidance = noise_guidance - noise_guidance_safety + return noise_guidance, safety_momentum + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str], + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + sld_guidance_scale: float | None = 1000, + sld_warmup_steps: int | None = 10, + sld_threshold: float | None = 0.01, + sld_momentum_scale: float | None = 0.3, + sld_mom_beta: float | None = 0.4, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + sld_guidance_scale (`float`, *optional*, defaults to 1000): + If `sld_guidance_scale < 1`, safety guidance is disabled. + sld_warmup_steps (`int`, *optional*, defaults to 10): + Number of warmup steps for safety guidance. SLD is only be applied for diffusion steps greater than + `sld_warmup_steps`. + sld_threshold (`float`, *optional*, defaults to 0.01): + Threshold that separates the hyperplane between appropriate and inappropriate images. + sld_momentum_scale (`float`, *optional*, defaults to 0.3): + Scale of the SLD momentum to be added to the safety guidance at each diffusion step. If set to 0.0, + momentum is disabled. Momentum is built up during warmup for diffusion steps smaller than + `sld_warmup_steps`. + sld_mom_beta (`float`, *optional*, defaults to 0.4): + Defines how safety guidance momentum builds up. `sld_mom_beta` indicates how much of the previous + momentum is kept. Momentum is built up during warmup for diffusion steps smaller than + `sld_warmup_steps`. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + + Examples: + + ```py + import torch + from diffusers import StableDiffusionPipelineSafe + from diffusers.pipelines.stable_diffusion_safe import SafetyConfig + + pipeline = StableDiffusionPipelineSafe.from_pretrained( + "AIML-TUDA/stable-diffusion-safe", torch_dtype=torch.float16 + ).to("cuda") + prompt = "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c. leyendecker" + image = pipeline(prompt=prompt, **SafetyConfig.MEDIUM).images[0] + ``` + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + enable_safety_guidance = sld_guidance_scale > 1.0 and do_classifier_free_guidance + if not enable_safety_guidance: + warnings.warn("Safety checker disabled!") + + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) + if do_classifier_free_guidance: + if enable_safety_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds, image_embeds]) + else: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + safety_momentum = None + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * (3 if enable_safety_guidance else 2)) + if do_classifier_free_guidance + else latents + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, t, encoder_hidden_states=prompt_embeds, added_cond_kwargs=added_cond_kwargs + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_out = noise_pred.chunk((3 if enable_safety_guidance else 2)) + noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1] + + # default classifier free guidance + noise_guidance = noise_pred_text - noise_pred_uncond + + # Perform SLD guidance + if enable_safety_guidance: + if safety_momentum is None: + safety_momentum = torch.zeros_like(noise_guidance) + noise_pred_safety_concept = noise_pred_out[2] + + # Equation 6 + scale = torch.clamp( + torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0 + ) + + # Equation 6 + safety_concept_scale = torch.where( + (noise_pred_text - noise_pred_safety_concept) >= sld_threshold, + torch.zeros_like(scale), + scale, + ) + + # Equation 4 + noise_guidance_safety = torch.mul( + (noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale + ) + + # Equation 7 + noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum + + # Equation 8 + safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety + + if i >= sld_warmup_steps: # Warmup + # Equation 3 + noise_guidance = noise_guidance - noise_guidance_safety + + noise_pred = noise_pred_uncond + guidance_scale * noise_guidance + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept, flagged_images = self.run_safety_checker( + image, device, prompt_embeds.dtype, enable_safety_guidance + ) + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + if flagged_images is not None: + flagged_images = self.numpy_to_pil(flagged_images) + + if not return_dict: + return ( + image, + has_nsfw_concept, + self._safety_text_concept if enable_safety_guidance else None, + flagged_images, + ) + + return StableDiffusionSafePipelineOutput( + images=image, + nsfw_content_detected=has_nsfw_concept, + applied_safety_concept=self._safety_text_concept if enable_safety_guidance else None, + unsafe_images=flagged_images, + ) diff --git a/diffusers/pipelines/stable_diffusion_safe/safety_checker.py b/diffusers/pipelines/stable_diffusion_safe/safety_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..1f6ad5f2a348c3a0aed208011e40e507fe0195e5 --- /dev/null +++ b/diffusers/pipelines/stable_diffusion_safe/safety_checker.py @@ -0,0 +1,109 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel + +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +def cosine_distance(image_embeds, text_embeds): + normalized_image_embeds = nn.functional.normalize(image_embeds) + normalized_text_embeds = nn.functional.normalize(text_embeds) + return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) + + +class SafeStableDiffusionSafetyChecker(PreTrainedModel): + config_class = CLIPConfig + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPConfig): + super().__init__(config) + + self.vision_model = CLIPVisionModel(config.vision_config) + self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False) + + self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False) + self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False) + + self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False) + self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False) + + @torch.no_grad() + def forward(self, clip_input, images): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy() + cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy() + + result = [] + batch_size = image_embeds.shape[0] + for i in range(batch_size): + result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} + + # increase this value to create a stronger `nfsw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + for concept_idx in range(len(special_cos_dist[0])): + concept_cos = special_cos_dist[i][concept_idx] + concept_threshold = self.special_care_embeds_weights[concept_idx].item() + result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["special_scores"][concept_idx] > 0: + result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) + adjustment = 0.01 + + for concept_idx in range(len(cos_dist[0])): + concept_cos = cos_dist[i][concept_idx] + concept_threshold = self.concept_embeds_weights[concept_idx].item() + result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["concept_scores"][concept_idx] > 0: + result_img["bad_concepts"].append(concept_idx) + + result.append(result_img) + + has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] + + return images, has_nsfw_concepts + + @torch.no_grad() + def forward_onnx(self, clip_input: torch.Tensor, images: torch.Tensor): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) + cos_dist = cosine_distance(image_embeds, self.concept_embeds) + + # increase this value to create a stronger `nsfw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment + # special_scores = special_scores.round(decimals=3) + special_care = torch.any(special_scores > 0, dim=1) + special_adjustment = special_care * 0.01 + special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) + + concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment + # concept_scores = concept_scores.round(decimals=3) + has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) + + return images, has_nsfw_concepts diff --git a/diffusers/pipelines/unclip/__init__.py b/diffusers/pipelines/unclip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c89e899463beede59b8ccf02688f6168b8ee3d77 --- /dev/null +++ b/diffusers/pipelines/unclip/__init__.py @@ -0,0 +1,52 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import UnCLIPImageVariationPipeline, UnCLIPPipeline + + _dummy_objects.update( + {"UnCLIPImageVariationPipeline": UnCLIPImageVariationPipeline, "UnCLIPPipeline": UnCLIPPipeline} + ) +else: + _import_structure["pipeline_unclip"] = ["UnCLIPPipeline"] + _import_structure["pipeline_unclip_image_variation"] = ["UnCLIPImageVariationPipeline"] + _import_structure["text_proj"] = ["UnCLIPTextProjModel"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_unclip import UnCLIPPipeline + from .pipeline_unclip_image_variation import UnCLIPImageVariationPipeline + from .text_proj import UnCLIPTextProjModel + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/unclip/pipeline_unclip.py b/diffusers/pipelines/unclip/pipeline_unclip.py new file mode 100644 index 0000000000000000000000000000000000000000..430f1a1e52651b9f98174a081567f463235fcee3 --- /dev/null +++ b/diffusers/pipelines/unclip/pipeline_unclip.py @@ -0,0 +1,503 @@ +# Copyright 2025 Kakao Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import torch +from torch.nn import functional as F +from transformers import CLIPTextModelWithProjection, CLIPTokenizer +from transformers.models.clip.modeling_clip import CLIPTextModelOutput + +from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel +from ...schedulers import UnCLIPScheduler +from ...utils import is_torch_xla_available, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput +from .text_proj import UnCLIPTextProjModel + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class UnCLIPPipeline(DeprecatedPipelineMixin, DiffusionPipeline): + """ + Pipeline for text-to-image generation using unCLIP. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`~transformers.CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + text_proj ([`UnCLIPTextProjModel`]): + Utility class to prepare and combine the embeddings before they are passed to the decoder. + decoder ([`UNet2DConditionModel`]): + The decoder to invert the image embedding into an image. + super_res_first ([`UNet2DModel`]): + Super resolution UNet. Used in all but the last step of the super resolution diffusion process. + super_res_last ([`UNet2DModel`]): + Super resolution UNet. Used in the last step of the super resolution diffusion process. + prior_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the prior denoising process (a modified [`DDPMScheduler`]). + decoder_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the decoder denoising process (a modified [`DDPMScheduler`]). + super_res_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the super resolution denoising process (a modified [`DDPMScheduler`]). + + """ + + _last_supported_version = "0.33.1" + _exclude_from_cpu_offload = ["prior"] + + prior: PriorTransformer + decoder: UNet2DConditionModel + text_proj: UnCLIPTextProjModel + text_encoder: CLIPTextModelWithProjection + tokenizer: CLIPTokenizer + super_res_first: UNet2DModel + super_res_last: UNet2DModel + + prior_scheduler: UnCLIPScheduler + decoder_scheduler: UnCLIPScheduler + super_res_scheduler: UnCLIPScheduler + + model_cpu_offload_seq = "text_encoder->text_proj->decoder->super_res_first->super_res_last" + + def __init__( + self, + prior: PriorTransformer, + decoder: UNet2DConditionModel, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_proj: UnCLIPTextProjModel, + super_res_first: UNet2DModel, + super_res_last: UNet2DModel, + prior_scheduler: UnCLIPScheduler, + decoder_scheduler: UnCLIPScheduler, + super_res_scheduler: UnCLIPScheduler, + ): + super().__init__() + + self.register_modules( + prior=prior, + decoder=decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_proj=text_proj, + super_res_first=super_res_first, + super_res_last=super_res_last, + prior_scheduler=prior_scheduler, + decoder_scheduler=decoder_scheduler, + super_res_scheduler=super_res_scheduler, + ) + + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + text_model_output: CLIPTextModelOutput | tuple | None = None, + text_attention_mask: torch.Tensor | None = None, + ): + if text_model_output is None: + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + prompt_embeds = text_encoder_output.text_embeds + text_enc_hid_states = text_encoder_output.last_hidden_state + + else: + batch_size = text_model_output[0].shape[0] + prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1] + text_mask = text_attention_mask + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds + uncond_text_enc_hid_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_enc_hid_states.shape[1] + uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1) + uncond_text_enc_hid_states = uncond_text_enc_hid_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_enc_hid_states, text_mask + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str] | None = None, + num_images_per_prompt: int = 1, + prior_num_inference_steps: int = 25, + decoder_num_inference_steps: int = 25, + super_res_num_inference_steps: int = 7, + generator: torch.Generator | list[torch.Generator] | None = None, + prior_latents: torch.Tensor | None = None, + decoder_latents: torch.Tensor | None = None, + super_res_latents: torch.Tensor | None = None, + text_model_output: CLIPTextModelOutput | tuple | None = None, + text_attention_mask: torch.Tensor | None = None, + prior_guidance_scale: float = 4.0, + decoder_guidance_scale: float = 8.0, + output_type: str | None = "pil", + return_dict: bool = True, + ): + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide image generation. This can only be left undefined if `text_model_output` + and `text_attention_mask` is passed. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prior_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the prior. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + decoder_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + super_res_num_inference_steps (`int`, *optional*, defaults to 7): + The number of denoising steps for super resolution. More denoising steps usually lead to a higher + quality image at the expense of slower inference. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prior_latents (`torch.Tensor` of shape (batch size, embeddings dimension), *optional*): + Pre-generated noisy latents to be used as inputs for the prior. + decoder_latents (`torch.Tensor` of shape (batch size, channels, height, width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + super_res_latents (`torch.Tensor` of shape (batch size, channels, super res height, super res width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + decoder_guidance_scale (`float`, *optional*, defaults to 4.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + text_model_output (`CLIPTextModelOutput`, *optional*): + Pre-defined [`CLIPTextModel`] outputs that can be derived from the text encoder. Pre-defined text + outputs can be passed for tasks like text embedding interpolations. Make sure to also pass + `text_attention_mask` in this case. `prompt` can the be left `None`. + text_attention_mask (`torch.Tensor`, *optional*): + Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention + masks are necessary when passing `text_model_output`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + if prompt is not None: + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + else: + batch_size = text_model_output[0].shape[0] + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0 + + prompt_embeds, text_enc_hid_states, text_mask = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask + ) + + # prior + + self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device) + prior_timesteps_tensor = self.prior_scheduler.timesteps + + embedding_dim = self.prior.config.embedding_dim + + prior_latents = self.prepare_latents( + (batch_size, embedding_dim), + prompt_embeds.dtype, + device, + generator, + prior_latents, + self.prior_scheduler, + ) + + for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents + + predicted_image_embedding = self.prior( + latent_model_input, + timestep=t, + proj_embedding=prompt_embeds, + encoder_hidden_states=text_enc_hid_states, + attention_mask=text_mask, + ).predicted_image_embedding + + if do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + if i + 1 == prior_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = prior_timesteps_tensor[i + 1] + + prior_latents = self.prior_scheduler.step( + predicted_image_embedding, + timestep=t, + sample=prior_latents, + generator=generator, + prev_timestep=prev_timestep, + ).prev_sample + + prior_latents = self.prior.post_process_latents(prior_latents) + + image_embeddings = prior_latents + + # done prior + + # decoder + + text_enc_hid_states, additive_clip_time_embeddings = self.text_proj( + image_embeddings=image_embeddings, + prompt_embeds=prompt_embeds, + text_encoder_hidden_states=text_enc_hid_states, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + if device.type == "mps": + # HACK: MPS: There is a panic when padding bool tensors, + # so cast to int tensor for the pad and back to bool afterwards + text_mask = text_mask.type(torch.int) + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) + decoder_text_mask = decoder_text_mask.type(torch.bool) + else: + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True) + + self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) + decoder_timesteps_tensor = self.decoder_scheduler.timesteps + + num_channels_latents = self.decoder.config.in_channels + height = self.decoder.config.sample_size + width = self.decoder.config.sample_size + + decoder_latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + text_enc_hid_states.dtype, + device, + generator, + decoder_latents, + self.decoder_scheduler, + ) + + for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents + + noise_pred = self.decoder( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_enc_hid_states, + class_labels=additive_clip_time_embeddings, + attention_mask=decoder_text_mask, + ).sample + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if i + 1 == decoder_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = decoder_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + decoder_latents = self.decoder_scheduler.step( + noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + decoder_latents = decoder_latents.clamp(-1, 1) + + image_small = decoder_latents + + # done decoder + + # super res + + self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) + super_res_timesteps_tensor = self.super_res_scheduler.timesteps + + channels = self.super_res_first.config.in_channels // 2 + height = self.super_res_first.config.sample_size + width = self.super_res_first.config.sample_size + + super_res_latents = self.prepare_latents( + (batch_size, channels, height, width), + image_small.dtype, + device, + generator, + super_res_latents, + self.super_res_scheduler, + ) + + if device.type == "mps": + # MPS does not support many interpolations + image_upscaled = F.interpolate(image_small, size=[height, width]) + else: + interpolate_antialias = {} + if "antialias" in inspect.signature(F.interpolate).parameters: + interpolate_antialias["antialias"] = True + + image_upscaled = F.interpolate( + image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias + ) + + for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): + # no classifier free guidance + + if i == super_res_timesteps_tensor.shape[0] - 1: + unet = self.super_res_last + else: + unet = self.super_res_first + + latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) + + noise_pred = unet( + sample=latent_model_input, + timestep=t, + ).sample + + if i + 1 == super_res_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = super_res_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + super_res_latents = self.super_res_scheduler.step( + noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + if XLA_AVAILABLE: + xm.mark_step() + + image = super_res_latents + # done super res + + self.maybe_free_model_hooks() + + # post processing + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py b/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py new file mode 100644 index 0000000000000000000000000000000000000000..d0d8bdc447879151ea401e42062a8236618a815c --- /dev/null +++ b/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py @@ -0,0 +1,430 @@ +# Copyright 2025 Kakao Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import PIL.Image +import torch +from torch.nn import functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...models import UNet2DConditionModel, UNet2DModel +from ...schedulers import UnCLIPScheduler +from ...utils import is_torch_xla_available, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput +from .text_proj import UnCLIPTextProjModel + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class UnCLIPImageVariationPipeline(DeprecatedPipelineMixin, DiffusionPipeline): + """ + Pipeline to generate image variations from an input image using UnCLIP. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`~transformers.CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `image_encoder`. + image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + text_proj ([`UnCLIPTextProjModel`]): + Utility class to prepare and combine the embeddings before they are passed to the decoder. + decoder ([`UNet2DConditionModel`]): + The decoder to invert the image embedding into an image. + super_res_first ([`UNet2DModel`]): + Super resolution UNet. Used in all but the last step of the super resolution diffusion process. + super_res_last ([`UNet2DModel`]): + Super resolution UNet. Used in the last step of the super resolution diffusion process. + decoder_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the decoder denoising process (a modified [`DDPMScheduler`]). + super_res_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the super resolution denoising process (a modified [`DDPMScheduler`]). + """ + + _last_supported_version = "0.33.1" + decoder: UNet2DConditionModel + text_proj: UnCLIPTextProjModel + text_encoder: CLIPTextModelWithProjection + tokenizer: CLIPTokenizer + feature_extractor: CLIPImageProcessor + image_encoder: CLIPVisionModelWithProjection + super_res_first: UNet2DModel + super_res_last: UNet2DModel + + decoder_scheduler: UnCLIPScheduler + super_res_scheduler: UnCLIPScheduler + model_cpu_offload_seq = "text_encoder->image_encoder->text_proj->decoder->super_res_first->super_res_last" + + def __init__( + self, + decoder: UNet2DConditionModel, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_proj: UnCLIPTextProjModel, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection, + super_res_first: UNet2DModel, + super_res_last: UNet2DModel, + decoder_scheduler: UnCLIPScheduler, + super_res_scheduler: UnCLIPScheduler, + ): + super().__init__() + + self.register_modules( + decoder=decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_proj=text_proj, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + super_res_first=super_res_first, + super_res_last=super_res_last, + decoder_scheduler=decoder_scheduler, + super_res_scheduler=super_res_scheduler, + ) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + prompt_embeds = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds + uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: torch.Tensor | None = None): + dtype = next(self.image_encoder.parameters()).dtype + + if image_embeddings is None: + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(images=image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + + image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + + return image_embeddings + + @torch.no_grad() + def __call__( + self, + image: PIL.Image.Image | list[PIL.Image.Image] | torch.Tensor | None = None, + num_images_per_prompt: int = 1, + decoder_num_inference_steps: int = 25, + super_res_num_inference_steps: int = 7, + generator: torch.Generator | None = None, + decoder_latents: torch.Tensor | None = None, + super_res_latents: torch.Tensor | None = None, + image_embeddings: torch.Tensor | None = None, + decoder_guidance_scale: float = 8.0, + output_type: str | None = "pil", + return_dict: bool = True, + ): + """ + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `list[PIL.Image.Image]` or `torch.Tensor`): + `Image` or tensor representing an image batch to be used as the starting point. If you provide a + tensor, it needs to be compatible with the [`CLIPImageProcessor`] + [configuration](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json). + Can be left as `None` only when `image_embeddings` are passed. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + decoder_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + super_res_num_inference_steps (`int`, *optional*, defaults to 7): + The number of denoising steps for super resolution. More denoising steps usually lead to a higher + quality image at the expense of slower inference. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + decoder_latents (`torch.Tensor` of shape (batch size, channels, height, width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + super_res_latents (`torch.Tensor` of shape (batch size, channels, super res height, super res width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + decoder_guidance_scale (`float`, *optional*, defaults to 4.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + image_embeddings (`torch.Tensor`, *optional*): + Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings + can be passed for tasks like image interpolations. `image` can be left as `None`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + if image is not None: + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + else: + batch_size = image_embeddings.shape[0] + + prompt = [""] * batch_size + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = decoder_guidance_scale > 1.0 + + prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance + ) + + image_embeddings = self._encode_image(image, device, num_images_per_prompt, image_embeddings) + + # decoder + text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( + image_embeddings=image_embeddings, + prompt_embeds=prompt_embeds, + text_encoder_hidden_states=text_encoder_hidden_states, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + if device.type == "mps": + # HACK: MPS: There is a panic when padding bool tensors, + # so cast to int tensor for the pad and back to bool afterwards + text_mask = text_mask.type(torch.int) + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) + decoder_text_mask = decoder_text_mask.type(torch.bool) + else: + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True) + + self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) + decoder_timesteps_tensor = self.decoder_scheduler.timesteps + + num_channels_latents = self.decoder.config.in_channels + height = self.decoder.config.sample_size + width = self.decoder.config.sample_size + + if decoder_latents is None: + decoder_latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + device, + generator, + decoder_latents, + self.decoder_scheduler, + ) + + for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents + + noise_pred = self.decoder( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + class_labels=additive_clip_time_embeddings, + attention_mask=decoder_text_mask, + ).sample + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if i + 1 == decoder_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = decoder_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + decoder_latents = self.decoder_scheduler.step( + noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + decoder_latents = decoder_latents.clamp(-1, 1) + + image_small = decoder_latents + + # done decoder + + # super res + + self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) + super_res_timesteps_tensor = self.super_res_scheduler.timesteps + + channels = self.super_res_first.config.in_channels // 2 + height = self.super_res_first.config.sample_size + width = self.super_res_first.config.sample_size + + if super_res_latents is None: + super_res_latents = self.prepare_latents( + (batch_size, channels, height, width), + image_small.dtype, + device, + generator, + super_res_latents, + self.super_res_scheduler, + ) + + if device.type == "mps": + # MPS does not support many interpolations + image_upscaled = F.interpolate(image_small, size=[height, width]) + else: + interpolate_antialias = {} + if "antialias" in inspect.signature(F.interpolate).parameters: + interpolate_antialias["antialias"] = True + + image_upscaled = F.interpolate( + image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias + ) + + for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): + # no classifier free guidance + + if i == super_res_timesteps_tensor.shape[0] - 1: + unet = self.super_res_last + else: + unet = self.super_res_first + + latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) + + noise_pred = unet( + sample=latent_model_input, + timestep=t, + ).sample + + if i + 1 == super_res_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = super_res_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + super_res_latents = self.super_res_scheduler.step( + noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + if XLA_AVAILABLE: + xm.mark_step() + + image = super_res_latents + + # done super res + self.maybe_free_model_hooks() + + # post processing + + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/diffusers/pipelines/unclip/text_proj.py b/diffusers/pipelines/unclip/text_proj.py new file mode 100644 index 0000000000000000000000000000000000000000..5e04e48ba621caa4a617e52c5b83a1573bb11975 --- /dev/null +++ b/diffusers/pipelines/unclip/text_proj.py @@ -0,0 +1,86 @@ +# Copyright 2025 Kakao Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin + + +class UnCLIPTextProjModel(ModelMixin, ConfigMixin): + """ + Utility class for CLIP embeddings. Used to combine the image and text embeddings into a format usable by the + decoder. + + For more details, see the original paper: https://huggingface.co/papers/2204.06125 section 2.1 + """ + + @register_to_config + def __init__( + self, + *, + clip_extra_context_tokens: int = 4, + clip_embeddings_dim: int = 768, + time_embed_dim: int, + cross_attention_dim, + ): + super().__init__() + + self.learned_classifier_free_guidance_embeddings = nn.Parameter(torch.zeros(clip_embeddings_dim)) + + # parameters for additional clip time embeddings + self.embedding_proj = nn.Linear(clip_embeddings_dim, time_embed_dim) + self.clip_image_embeddings_project_to_time_embeddings = nn.Linear(clip_embeddings_dim, time_embed_dim) + + # parameters for encoder hidden states + self.clip_extra_context_tokens = clip_extra_context_tokens + self.clip_extra_context_tokens_proj = nn.Linear( + clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim + ) + self.encoder_hidden_states_proj = nn.Linear(clip_embeddings_dim, cross_attention_dim) + self.text_encoder_hidden_states_norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, *, image_embeddings, prompt_embeds, text_encoder_hidden_states, do_classifier_free_guidance): + if do_classifier_free_guidance: + # Add the classifier free guidance embeddings to the image embeddings + image_embeddings_batch_size = image_embeddings.shape[0] + classifier_free_guidance_embeddings = self.learned_classifier_free_guidance_embeddings.unsqueeze(0) + classifier_free_guidance_embeddings = classifier_free_guidance_embeddings.expand( + image_embeddings_batch_size, -1 + ) + image_embeddings = torch.cat([classifier_free_guidance_embeddings, image_embeddings], dim=0) + + # The image embeddings batch size and the text embeddings batch size are equal + assert image_embeddings.shape[0] == prompt_embeds.shape[0] + + batch_size = prompt_embeds.shape[0] + + # "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and + # adding CLIP embeddings to the existing timestep embedding, ... + time_projected_prompt_embeds = self.embedding_proj(prompt_embeds) + time_projected_image_embeddings = self.clip_image_embeddings_project_to_time_embeddings(image_embeddings) + additive_clip_time_embeddings = time_projected_image_embeddings + time_projected_prompt_embeds + + # ... and by projecting CLIP embeddings into four + # extra tokens of context that are concatenated to the sequence of outputs from the GLIDE text encoder" + clip_extra_context_tokens = self.clip_extra_context_tokens_proj(image_embeddings) + clip_extra_context_tokens = clip_extra_context_tokens.reshape(batch_size, -1, self.clip_extra_context_tokens) + clip_extra_context_tokens = clip_extra_context_tokens.permute(0, 2, 1) + + text_encoder_hidden_states = self.encoder_hidden_states_proj(text_encoder_hidden_states) + text_encoder_hidden_states = self.text_encoder_hidden_states_norm(text_encoder_hidden_states) + text_encoder_hidden_states = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=1) + + return text_encoder_hidden_states, additive_clip_time_embeddings diff --git a/diffusers/pipelines/wuerstchen/__init__.py b/diffusers/pipelines/wuerstchen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ddb852d1931558fe0948e81e16cf9a92fc2a114b --- /dev/null +++ b/diffusers/pipelines/wuerstchen/__init__.py @@ -0,0 +1,56 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modeling_paella_vq_model"] = ["PaellaVQModel"] + _import_structure["modeling_wuerstchen_diffnext"] = ["WuerstchenDiffNeXt"] + _import_structure["modeling_wuerstchen_prior"] = ["WuerstchenPrior"] + _import_structure["pipeline_wuerstchen"] = ["WuerstchenDecoderPipeline"] + _import_structure["pipeline_wuerstchen_combined"] = ["WuerstchenCombinedPipeline"] + _import_structure["pipeline_wuerstchen_prior"] = ["DEFAULT_STAGE_C_TIMESTEPS", "WuerstchenPriorPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modeling_paella_vq_model import PaellaVQModel + from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt + from .modeling_wuerstchen_prior import WuerstchenPrior + from .pipeline_wuerstchen import WuerstchenDecoderPipeline + from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline + from .pipeline_wuerstchen_prior import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPriorPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py b/diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py new file mode 100644 index 0000000000000000000000000000000000000000..932c7ac618f6a3ed7c110a7fd7e205f035f789b9 --- /dev/null +++ b/diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py @@ -0,0 +1,171 @@ +# Copyright (c) 2022 Dominic Rampas MIT License +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.autoencoders.vae import DecoderOutput, VectorQuantizer +from ...models.modeling_utils import ModelMixin +from ...models.vq_model import VQEncoderOutput +from ...utils.accelerate_utils import apply_forward_hook + + +class MixingResidualBlock(nn.Module): + """ + Residual block with mixing used by Paella's VQ-VAE. + """ + + def __init__(self, inp_channels, embed_dim): + super().__init__() + # depthwise + self.norm1 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6) + self.depthwise = nn.Sequential( + nn.ReplicationPad2d(1), nn.Conv2d(inp_channels, inp_channels, kernel_size=3, groups=inp_channels) + ) + + # channelwise + self.norm2 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(inp_channels, embed_dim), nn.GELU(), nn.Linear(embed_dim, inp_channels) + ) + + self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) + + def forward(self, x): + mods = self.gammas + x_temp = self.norm1(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * (1 + mods[0]) + mods[1] + x = x + self.depthwise(x_temp) * mods[2] + x_temp = self.norm2(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * (1 + mods[3]) + mods[4] + x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] + return x + + +class PaellaVQModel(ModelMixin, ConfigMixin): + r"""VQ-VAE model from Paella model. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + up_down_scale_factor (int, *optional*, defaults to 2): Up and Downscale factor of the input image. + levels (int, *optional*, defaults to 2): Number of levels in the model. + bottleneck_blocks (int, *optional*, defaults to 12): Number of bottleneck blocks in the model. + embed_dim (int, *optional*, defaults to 384): Number of hidden channels in the model. + latent_channels (int, *optional*, defaults to 4): Number of latent channels in the VQ-VAE model. + num_vq_embeddings (int, *optional*, defaults to 8192): Number of codebook vectors in the VQ-VAE. + scale_factor (float, *optional*, defaults to 0.3764): Scaling factor of the latent space. + """ + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_down_scale_factor: int = 2, + levels: int = 2, + bottleneck_blocks: int = 12, + embed_dim: int = 384, + latent_channels: int = 4, + num_vq_embeddings: int = 8192, + scale_factor: float = 0.3764, + ): + super().__init__() + + c_levels = [embed_dim // (2**i) for i in reversed(range(levels))] + # Encoder blocks + self.in_block = nn.Sequential( + nn.PixelUnshuffle(up_down_scale_factor), + nn.Conv2d(in_channels * up_down_scale_factor**2, c_levels[0], kernel_size=1), + ) + down_blocks = [] + for i in range(levels): + if i > 0: + down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) + block = MixingResidualBlock(c_levels[i], c_levels[i] * 4) + down_blocks.append(block) + down_blocks.append( + nn.Sequential( + nn.Conv2d(c_levels[-1], latent_channels, kernel_size=1, bias=False), + nn.BatchNorm2d(latent_channels), # then normalize them to have mean 0 and std 1 + ) + ) + self.down_blocks = nn.Sequential(*down_blocks) + + # Vector Quantizer + self.vquantizer = VectorQuantizer(num_vq_embeddings, vq_embed_dim=latent_channels, legacy=False, beta=0.25) + + # Decoder blocks + up_blocks = [nn.Sequential(nn.Conv2d(latent_channels, c_levels[-1], kernel_size=1))] + for i in range(levels): + for j in range(bottleneck_blocks if i == 0 else 1): + block = MixingResidualBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4) + up_blocks.append(block) + if i < levels - 1: + up_blocks.append( + nn.ConvTranspose2d( + c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, padding=1 + ) + ) + self.up_blocks = nn.Sequential(*up_blocks) + self.out_block = nn.Sequential( + nn.Conv2d(c_levels[0], out_channels * up_down_scale_factor**2, kernel_size=1), + nn.PixelShuffle(up_down_scale_factor), + ) + + @apply_forward_hook + def encode(self, x: torch.Tensor, return_dict: bool = True) -> VQEncoderOutput: + h = self.in_block(x) + h = self.down_blocks(h) + + if not return_dict: + return (h,) + + return VQEncoderOutput(latents=h) + + @apply_forward_hook + def decode( + self, h: torch.Tensor, force_not_quantize: bool = True, return_dict: bool = True + ) -> DecoderOutput | torch.Tensor: + if not force_not_quantize: + quant, _, _ = self.vquantizer(h) + else: + quant = h + + x = self.up_blocks(quant) + dec = self.out_block(x) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward(self, sample: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + r""" + Args: + sample (`torch.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + h = self.encode(x).latents + dec = self.decode(h).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py b/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py new file mode 100644 index 0000000000000000000000000000000000000000..73e71b3076fbca259ae76138bc4ab3d3797e2755 --- /dev/null +++ b/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py @@ -0,0 +1,76 @@ +import torch +import torch.nn as nn + +from ...models.attention_processor import Attention + + +class WuerstchenLayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + x = x.permute(0, 2, 3, 1) + x = super().forward(x) + return x.permute(0, 3, 1, 2) + + +class TimestepBlock(nn.Module): + def __init__(self, c, c_timestep): + super().__init__() + + self.mapper = nn.Linear(c_timestep, c * 2) + + def forward(self, x, t): + a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1) + return x * (1 + a) + b + + +class ResBlock(nn.Module): + def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): + super().__init__() + + self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) + self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c) + ) + + def forward(self, x, x_skip=None): + x_res = x + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.norm(self.depthwise(x)).permute(0, 2, 3, 1) + x = self.channelwise(x).permute(0, 3, 1, 2) + return x + x_res + + +# from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105 +class GlobalResponseNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + agg_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + stand_div_norm = agg_norm / (agg_norm.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * stand_div_norm) + self.beta + x + + +class AttnBlock(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): + super().__init__() + + self.self_attn = self_attn + self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True) + self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c)) + + def forward(self, x, kv): + kv = self.kv_mapper(kv) + norm_x = self.norm(x) + if self.self_attn: + batch_size, channel, _, _ = x.shape + kv = torch.cat([norm_x.view(batch_size, channel, -1).transpose(1, 2), kv], dim=1) + x = x + self.attention(norm_x, encoder_hidden_states=kv) + return x diff --git a/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py b/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py new file mode 100644 index 0000000000000000000000000000000000000000..77ae597655d164e25add03a641a6b8eca5396200 --- /dev/null +++ b/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py @@ -0,0 +1,254 @@ +# Copyright (c) 2023 Dominic Rampas MIT License +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import numpy as np +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin +from .modeling_wuerstchen_common import AttnBlock, GlobalResponseNorm, TimestepBlock, WuerstchenLayerNorm + + +class WuerstchenDiffNeXt(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + c_in=4, + c_out=4, + c_r=64, + patch_size=2, + c_cond=1024, + c_hidden=[320, 640, 1280, 1280], + nhead=[-1, 10, 20, 20], + blocks=[4, 4, 14, 4], + level_config=["CT", "CTA", "CTA", "CTA"], + inject_effnet=[False, True, True, True], + effnet_embd=16, + clip_embd=1024, + kernel_size=3, + dropout=0.1, + ): + super().__init__() + self.c_r = c_r + self.c_cond = c_cond + if not isinstance(dropout, list): + dropout = [dropout] * len(c_hidden) + + # CONDITIONING + self.clip_mapper = nn.Linear(clip_embd, c_cond) + self.effnet_mappers = nn.ModuleList( + [ + nn.Conv2d(effnet_embd, c_cond, kernel_size=1) if inject else None + for inject in inject_effnet + list(reversed(inject_effnet)) + ] + ) + self.seq_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + nn.Conv2d(c_in * (patch_size**2), c_hidden[0], kernel_size=1), + WuerstchenLayerNorm(c_hidden[0], elementwise_affine=False, eps=1e-6), + ) + + def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0): + if block_type == "C": + return ResBlockStageB(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) + elif block_type == "A": + return AttnBlock(c_hidden, c_cond, nhead, self_attn=True, dropout=dropout) + elif block_type == "T": + return TimestepBlock(c_hidden, c_r) + else: + raise ValueError(f"Block type {block_type} not supported") + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + for i in range(len(c_hidden)): + down_block = nn.ModuleList() + if i > 0: + down_block.append( + nn.Sequential( + WuerstchenLayerNorm(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), + nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2), + ) + ) + for _ in range(blocks[i]): + for block_type in level_config[i]: + c_skip = c_cond if inject_effnet[i] else 0 + down_block.append(get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i])) + self.down_blocks.append(down_block) + + # -- up blocks + self.up_blocks = nn.ModuleList() + for i in reversed(range(len(c_hidden))): + up_block = nn.ModuleList() + for j in range(blocks[i]): + for k, block_type in enumerate(level_config[i]): + c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 + c_skip += c_cond if inject_effnet[i] else 0 + up_block.append(get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i])) + if i > 0: + up_block.append( + nn.Sequential( + WuerstchenLayerNorm(c_hidden[i], elementwise_affine=False, eps=1e-6), + nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2), + ) + ) + self.up_blocks.append(up_block) + + # OUTPUT + self.clf = nn.Sequential( + WuerstchenLayerNorm(c_hidden[0], elementwise_affine=False, eps=1e-6), + nn.Conv2d(c_hidden[0], 2 * c_out * (patch_size**2), kernel_size=1), + nn.PixelShuffle(patch_size), + ) + + # --- WEIGHT INIT --- + self.apply(self._init_weights) + + def _init_weights(self, m): + # General init + if isinstance(m, (nn.Conv2d, nn.Linear)): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + for mapper in self.effnet_mappers: + if mapper is not None: + nn.init.normal_(mapper.weight, std=0.02) # conditionings + nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings + nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + nn.init.constant_(self.clf[1].weight, 0) # outputs + + # blocks + for level_block in self.down_blocks + self.up_blocks: + for block in level_block: + if isinstance(block, ResBlockStageB): + block.channelwise[-1].weight.data *= np.sqrt(1 / sum(self.config.blocks)) + elif isinstance(block, TimestepBlock): + nn.init.constant_(block.mapper.weight, 0) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode="constant") + return emb.to(dtype=r.dtype) + + def gen_c_embeddings(self, clip): + clip = self.clip_mapper(clip) + clip = self.seq_norm(clip) + return clip + + def _down_encode(self, x, r_embed, effnet, clip=None): + level_outputs = [] + for i, down_block in enumerate(self.down_blocks): + effnet_c = None + for block in down_block: + if isinstance(block, ResBlockStageB): + if effnet_c is None and self.effnet_mappers[i] is not None: + dtype = effnet.dtype + effnet_c = self.effnet_mappers[i]( + nn.functional.interpolate( + effnet.float(), size=x.shape[-2:], mode="bicubic", antialias=True, align_corners=True + ).to(dtype) + ) + skip = effnet_c if self.effnet_mappers[i] is not None else None + x = block(x, skip) + elif isinstance(block, AttnBlock): + x = block(x, clip) + elif isinstance(block, TimestepBlock): + x = block(x, r_embed) + else: + x = block(x) + level_outputs.insert(0, x) + return level_outputs + + def _up_decode(self, level_outputs, r_embed, effnet, clip=None): + x = level_outputs[0] + for i, up_block in enumerate(self.up_blocks): + effnet_c = None + for j, block in enumerate(up_block): + if isinstance(block, ResBlockStageB): + if effnet_c is None and self.effnet_mappers[len(self.down_blocks) + i] is not None: + dtype = effnet.dtype + effnet_c = self.effnet_mappers[len(self.down_blocks) + i]( + nn.functional.interpolate( + effnet.float(), size=x.shape[-2:], mode="bicubic", antialias=True, align_corners=True + ).to(dtype) + ) + skip = level_outputs[i] if j == 0 and i > 0 else None + if effnet_c is not None: + if skip is not None: + skip = torch.cat([skip, effnet_c], dim=1) + else: + skip = effnet_c + x = block(x, skip) + elif isinstance(block, AttnBlock): + x = block(x, clip) + elif isinstance(block, TimestepBlock): + x = block(x, r_embed) + else: + x = block(x) + return x + + def forward(self, x, r, effnet, clip=None, x_cat=None, eps=1e-3, return_noise=True): + if x_cat is not None: + x = torch.cat([x, x_cat], dim=1) + # Process the conditioning embeddings + r_embed = self.gen_r_embedding(r) + if clip is not None: + clip = self.gen_c_embeddings(clip) + + # Model Blocks + x_in = x + x = self.embedding(x) + level_outputs = self._down_encode(x, r_embed, effnet, clip) + x = self._up_decode(level_outputs, r_embed, effnet, clip) + a, b = self.clf(x).chunk(2, dim=1) + b = b.sigmoid() * (1 - eps * 2) + eps + if return_noise: + return (x_in - a) / b + else: + return a, b + + +class ResBlockStageB(nn.Module): + def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): + super().__init__() + self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) + self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c + c_skip, c * 4), + nn.GELU(), + GlobalResponseNorm(c * 4), + nn.Dropout(dropout), + nn.Linear(c * 4, c), + ) + + def forward(self, x, x_skip=None): + x_res = x + x = self.norm(self.depthwise(x)) + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + x_res diff --git a/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py new file mode 100644 index 0000000000000000000000000000000000000000..dbdd50871b43e6e4930451f2090a30fb95e14501 --- /dev/null +++ b/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -0,0 +1,113 @@ +# Copyright (c) 2023 Dominic Rampas MIT License +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin +from ...models.attention import AttentionMixin +from ...models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttnAddedKVProcessor, + AttnProcessor, +) +from ...models.modeling_utils import ModelMixin +from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm + + +class WuerstchenPrior(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): + unet_name = "prior" + _supports_gradient_checkpointing = True + + @register_to_config + def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1): + super().__init__() + + self.c_r = c_r + self.projection = nn.Conv2d(c_in, c, kernel_size=1) + self.cond_mapper = nn.Sequential( + nn.Linear(c_cond, c), + nn.LeakyReLU(0.2), + nn.Linear(c, c), + ) + + self.blocks = nn.ModuleList() + for _ in range(depth): + self.blocks.append(ResBlock(c, dropout=dropout)) + self.blocks.append(TimestepBlock(c, c_r)) + self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout)) + self.out = nn.Sequential( + WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6), + nn.Conv2d(c, c_in * 2, kernel_size=1), + ) + + self.gradient_checkpointing = False + self.set_default_attn_processor() + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode="constant") + return emb.to(dtype=r.dtype) + + def forward(self, x, r, c): + x_in = x + x = self.projection(x) + c_embed = self.cond_mapper(c) + r_embed = self.gen_r_embedding(r) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for block in self.blocks: + if isinstance(block, AttnBlock): + x = self._gradient_checkpointing_func(block, x, c_embed) + elif isinstance(block, TimestepBlock): + x = self._gradient_checkpointing_func(block, x, r_embed) + else: + x = self._gradient_checkpointing_func(block, x) + else: + for block in self.blocks: + if isinstance(block, AttnBlock): + x = block(x, c_embed) + elif isinstance(block, TimestepBlock): + x = block(x, r_embed) + else: + x = block(x) + a, b = self.out(x).chunk(2, dim=1) + return (x_in - a) / ((1 - b).abs() + 1e-5) diff --git a/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py new file mode 100644 index 0000000000000000000000000000000000000000..cce05c1892013d5ec3dd69adc3b190abcefb7aad --- /dev/null +++ b/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -0,0 +1,449 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...schedulers import DDPMWuerstchenScheduler +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput +from .modeling_paella_vq_model import PaellaVQModel +from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import WuerstchenPriorPipeline, WuerstchenDecoderPipeline + + >>> prior_pipe = WuerstchenPriorPipeline.from_pretrained( + ... "warp-ai/wuerstchen-prior", torch_dtype=torch.float16 + ... ).to("cuda") + >>> gen_pipe = WuerstchenDecoderPipeline.from_pretrain("warp-ai/wuerstchen", torch_dtype=torch.float16).to( + ... "cuda" + ... ) + + >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" + >>> prior_output = pipe(prompt) + >>> images = gen_pipe(prior_output.image_embeddings, prompt=prompt) + ``` +""" + + +class WuerstchenDecoderPipeline(DeprecatedPipelineMixin, DiffusionPipeline): + """ + Pipeline for generating images from the Wuerstchen model. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + tokenizer (`CLIPTokenizer`): + The CLIP tokenizer. + text_encoder (`CLIPTextModel`): + The CLIP text encoder. + decoder ([`WuerstchenDiffNeXt`]): + The WuerstchenDiffNeXt unet decoder. + vqgan ([`PaellaVQModel`]): + The VQGAN model. + scheduler ([`DDPMWuerstchenScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + latent_dim_scale (float, `optional`, defaults to 10.67): + Multiplier to determine the VQ latent space size from the image embeddings. If the image embeddings are + height=24 and width=24, the VQ latent shape needs to be height=int(24*10.67)=256 and + width=int(24*10.67)=256 in order to match the training conditions. + """ + + model_cpu_offload_seq = "text_encoder->decoder->vqgan" + _callback_tensor_inputs = [ + "latents", + "text_encoder_hidden_states", + "negative_prompt_embeds", + "image_embeddings", + ] + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + decoder: WuerstchenDiffNeXt, + scheduler: DDPMWuerstchenScheduler, + vqgan: PaellaVQModel, + latent_dim_scale: float = 10.67, + ) -> None: + super().__init__() + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + decoder=decoder, + scheduler=scheduler, + vqgan=vqgan, + ) + self.register_to_config(latent_dim_scale=latent_dim_scale) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + attention_mask = attention_mask[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask.to(device)) + text_encoder_hidden_states = text_encoder_output.last_hidden_state + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_text_encoder_hidden_states = None + if do_classifier_free_guidance: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds_text_encoder_output = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device) + ) + + uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + return text_encoder_hidden_states, uncond_text_encoder_hidden_states + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image_embeddings: torch.Tensor | list[torch.Tensor], + prompt: str | list[str] = None, + num_inference_steps: int = 12, + timesteps: list[float] | None = None, + guidance_scale: float = 0.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image_embedding (`torch.Tensor` or `list[torch.Tensor]`): + Image Embeddings either extracted from an image or generated by a Prior Model. + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + num_inference_steps (`int`, *optional*, defaults to 12): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 0.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `decoder_guidance_scale` is defined as `w` of + equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by + setting `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are + closely linked to the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `decoder_guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is a list with the generated image + embeddings. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + # 0. Define commonly used variables + device = self._execution_device + dtype = self.decoder.dtype + self._guidance_scale = guidance_scale + + # 1. Check inputs. Raise error if not correct + if not isinstance(prompt, list): + if isinstance(prompt, str): + prompt = [prompt] + else: + raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.") + + if self.do_classifier_free_guidance: + if negative_prompt is not None and not isinstance(negative_prompt, list): + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + else: + raise TypeError( + f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}." + ) + + if isinstance(image_embeddings, list): + image_embeddings = torch.cat(image_embeddings, dim=0) + if isinstance(image_embeddings, np.ndarray): + image_embeddings = torch.Tensor(image_embeddings, device=device).to(dtype=dtype) + if not isinstance(image_embeddings, torch.Tensor): + raise TypeError( + f"'image_embeddings' must be of type 'torch.Tensor' or 'np.array', but got {type(image_embeddings)}." + ) + + if not isinstance(num_inference_steps, int): + raise TypeError( + f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}\ + In Case you want to provide explicit timesteps, please use the 'timesteps' argument." + ) + + # 2. Encode caption + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + image_embeddings.size(0) * num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + ) + text_encoder_hidden_states = ( + torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds + ) + effnet = ( + torch.cat([image_embeddings, torch.zeros_like(image_embeddings)]) + if self.do_classifier_free_guidance + else image_embeddings + ) + + # 3. Determine latent shape of latents + latent_height = int(image_embeddings.size(2) * self.config.latent_dim_scale) + latent_width = int(image_embeddings.size(3) * self.config.latent_dim_scale) + latent_features_shape = (image_embeddings.size(0) * num_images_per_prompt, 4, latent_height, latent_width) + + # 4. Prepare and set timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latents + latents = self.prepare_latents(latent_features_shape, dtype, device, generator, latents, self.scheduler) + + # 6. Run denoising loop + self._num_timesteps = len(timesteps[:-1]) + for i, t in enumerate(self.progress_bar(timesteps[:-1])): + ratio = t.expand(latents.size(0)).to(dtype) + # 7. Denoise latents + predicted_latents = self.decoder( + torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents, + r=torch.cat([ratio] * 2) if self.do_classifier_free_guidance else ratio, + effnet=effnet, + clip=text_encoder_hidden_states, + ) + + # 8. Check for classifier free guidance and apply it + if self.do_classifier_free_guidance: + predicted_latents_text, predicted_latents_uncond = predicted_latents.chunk(2) + predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale) + + # 9. Renoise latents to next timestep + latents = self.scheduler.step( + model_output=predicted_latents, + timestep=ratio, + sample=latents, + generator=generator, + ).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + image_embeddings = callback_outputs.pop("image_embeddings", image_embeddings) + text_encoder_hidden_states = callback_outputs.pop( + "text_encoder_hidden_states", text_encoder_hidden_states + ) + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type not in ["pt", "np", "pil", "latent"]: + raise ValueError( + f"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type={output_type}" + ) + + if not output_type == "latent": + # 10. Scale and decode the image latents with vq-vae + latents = self.vqgan.config.scale_factor * latents + images = self.vqgan.decode(latents).sample.clamp(0, 1) + if output_type == "np": + images = images.permute(0, 2, 3, 1).cpu().float().numpy() + elif output_type == "pil": + images = images.permute(0, 2, 3, 1).cpu().float().numpy() + images = self.numpy_to_pil(images) + else: + images = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return images + return ImagePipelineOutput(images) diff --git a/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py new file mode 100644 index 0000000000000000000000000000000000000000..16300a7c71d2446b381afc02951d7ac4c3d8ff32 --- /dev/null +++ b/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -0,0 +1,307 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable + +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...schedulers import DDPMWuerstchenScheduler +from ...utils import deprecate, replace_example_docstring +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline +from .modeling_paella_vq_model import PaellaVQModel +from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt +from .modeling_wuerstchen_prior import WuerstchenPrior +from .pipeline_wuerstchen import WuerstchenDecoderPipeline +from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline + + +TEXT2IMAGE_EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusions import WuerstchenCombinedPipeline + + >>> pipe = WuerstchenCombinedPipeline.from_pretrained("warp-ai/Wuerstchen", torch_dtype=torch.float16).to( + ... "cuda" + ... ) + >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" + >>> images = pipe(prompt=prompt) + ``` +""" + + +class WuerstchenCombinedPipeline(DeprecatedPipelineMixin, DiffusionPipeline): + """ + Combined Pipeline for text-to-image generation using Wuerstchen + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + tokenizer (`CLIPTokenizer`): + The decoder tokenizer to be used for text inputs. + text_encoder (`CLIPTextModel`): + The decoder text encoder to be used for text inputs. + decoder (`WuerstchenDiffNeXt`): + The decoder model to be used for decoder image generation pipeline. + scheduler (`DDPMWuerstchenScheduler`): + The scheduler to be used for decoder image generation pipeline. + vqgan (`PaellaVQModel`): + The VQGAN model to be used for decoder image generation pipeline. + prior_tokenizer (`CLIPTokenizer`): + The prior tokenizer to be used for text inputs. + prior_text_encoder (`CLIPTextModel`): + The prior text encoder to be used for text inputs. + prior_prior (`WuerstchenPrior`): + The prior model to be used for prior pipeline. + prior_scheduler (`DDPMWuerstchenScheduler`): + The scheduler to be used for prior pipeline. + """ + + _last_supported_version = "0.33.1" + _load_connected_pipes = True + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + decoder: WuerstchenDiffNeXt, + scheduler: DDPMWuerstchenScheduler, + vqgan: PaellaVQModel, + prior_tokenizer: CLIPTokenizer, + prior_text_encoder: CLIPTextModel, + prior_prior: WuerstchenPrior, + prior_scheduler: DDPMWuerstchenScheduler, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + decoder=decoder, + scheduler=scheduler, + vqgan=vqgan, + prior_prior=prior_prior, + prior_text_encoder=prior_text_encoder, + prior_tokenizer=prior_tokenizer, + prior_scheduler=prior_scheduler, + ) + self.prior_pipe = WuerstchenPriorPipeline( + prior=prior_prior, + text_encoder=prior_text_encoder, + tokenizer=prior_tokenizer, + scheduler=prior_scheduler, + ) + self.decoder_pipe = WuerstchenDecoderPipeline( + text_encoder=text_encoder, + tokenizer=tokenizer, + decoder=decoder, + scheduler=scheduler, + vqgan=vqgan, + ) + + def enable_xformers_memory_efficient_attention(self, attention_op: Callable | None = None): + self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) + + def enable_model_cpu_offload(self, gpu_id: int | None = None, device: torch.device | str = None): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) + self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) + + def enable_sequential_cpu_offload(self, gpu_id: int | None = None, device: torch.device | str = None): + r""" + Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗 + Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a + GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis. + Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower. + """ + self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + + def progress_bar(self, iterable=None, total=None): + self.prior_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.progress_bar(iterable=iterable, total=total) + + def set_progress_bar_config(self, **kwargs): + self.prior_pipe.set_progress_bar_config(**kwargs) + self.decoder_pipe.set_progress_bar_config(**kwargs) + + @torch.no_grad() + @replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + height: int = 512, + width: int = 512, + prior_num_inference_steps: int = 60, + prior_timesteps: list[float] | None = None, + prior_guidance_scale: float = 4.0, + num_inference_steps: int = 12, + decoder_timesteps: list[float] | None = None, + decoder_guidance_scale: float = 0.0, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + prior_callback_on_step_end: Callable[[int, int], None] | None = None, + prior_callback_on_step_end_tensor_inputs: list[str] = ["latents"], + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation for the prior and decoder. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* + prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` + input argument. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `prior_guidance_scale` is defined as `w` of + equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by + setting `prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are + closely linked to the text `prompt`, usually at the expense of lower image quality. + prior_num_inference_steps (`int | dict[float, int]`, *optional*, defaults to 60): + The number of prior denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. For more specific timestep spacing, you can pass customized + `prior_timesteps` + num_inference_steps (`int`, *optional*, defaults to 12): + The number of decoder denoising steps. More denoising steps usually lead to a higher quality image at + the expense of slower inference. For more specific timestep spacing, you can pass customized + `timesteps` + prior_timesteps (`list[float]`, *optional*): + Custom timesteps to use for the denoising process for the prior. If not defined, equal spaced + `prior_num_inference_steps` timesteps are used. Must be in descending order. + decoder_timesteps (`list[float]`, *optional*): + Custom timesteps to use for the denoising process for the decoder. If not defined, equal spaced + `num_inference_steps` timesteps are used. Must be in descending order. + decoder_guidance_scale (`float`, *optional*, defaults to 0.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + prior_callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `prior_callback_on_step_end(self: DiffusionPipeline, step: int, timestep: + int, callback_kwargs: Dict)`. + prior_callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the + list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in + the `._callback_tensor_inputs` attribute of your pipeline class. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + prior_kwargs = {} + if kwargs.get("prior_callback", None) is not None: + prior_kwargs["callback"] = kwargs.pop("prior_callback") + deprecate( + "prior_callback", + "1.0.0", + "Passing `prior_callback` as an input argument to `__call__` is deprecated, consider use `prior_callback_on_step_end`", + ) + if kwargs.get("prior_callback_steps", None) is not None: + deprecate( + "prior_callback_steps", + "1.0.0", + "Passing `prior_callback_steps` as an input argument to `__call__` is deprecated, consider use `prior_callback_on_step_end`", + ) + prior_kwargs["callback_steps"] = kwargs.pop("prior_callback_steps") + + prior_outputs = self.prior_pipe( + prompt=prompt if prompt_embeds is None else None, + height=height, + width=width, + num_inference_steps=prior_num_inference_steps, + timesteps=prior_timesteps, + guidance_scale=prior_guidance_scale, + negative_prompt=negative_prompt if negative_prompt_embeds is None else None, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + output_type="pt", + return_dict=False, + callback_on_step_end=prior_callback_on_step_end, + callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs, + **prior_kwargs, + ) + image_embeddings = prior_outputs[0] + + outputs = self.decoder_pipe( + image_embeddings=image_embeddings, + prompt=prompt if prompt is not None else "", + num_inference_steps=num_inference_steps, + timesteps=decoder_timesteps, + guidance_scale=decoder_guidance_scale, + negative_prompt=negative_prompt, + generator=generator, + output_type=output_type, + return_dict=return_dict, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + **kwargs, + ) + + return outputs diff --git a/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py new file mode 100644 index 0000000000000000000000000000000000000000..e79fcf8378aa05dc8ee7acd1678ca5476284b9bd --- /dev/null +++ b/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -0,0 +1,528 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from math import ceil +from typing import Callable + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...loaders import StableDiffusionLoraLoaderMixin +from ...schedulers import DDPMWuerstchenScheduler +from ...utils import BaseOutput, deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .modeling_wuerstchen_prior import WuerstchenPrior + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:] + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import WuerstchenPriorPipeline + + >>> prior_pipe = WuerstchenPriorPipeline.from_pretrained( + ... "warp-ai/wuerstchen-prior", torch_dtype=torch.float16 + ... ).to("cuda") + + >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" + >>> prior_output = pipe(prompt) + ``` +""" + + +@dataclass +class WuerstchenPriorPipelineOutput(BaseOutput): + """ + Output class for WuerstchenPriorPipeline. + + Args: + image_embeddings (`torch.Tensor` or `np.ndarray`) + Prior image embeddings for text prompt + + """ + + image_embeddings: torch.Tensor | np.ndarray + + +class WuerstchenPriorPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin): + """ + Pipeline for generating image prior for Wuerstchen. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + prior ([`Prior`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + scheduler ([`DDPMWuerstchenScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + latent_mean ('float', *optional*, defaults to 42.0): + Mean value for latent diffusers. + latent_std ('float', *optional*, defaults to 1.0): + Standard value for latent diffusers. + resolution_multiple ('float', *optional*, defaults to 42.67): + Default resolution for multiple images generated. + """ + + unet_name = "prior" + text_encoder_name = "text_encoder" + model_cpu_offload_seq = "text_encoder->prior" + _callback_tensor_inputs = ["latents", "text_encoder_hidden_states", "negative_prompt_embeds"] + _lora_loadable_modules = ["prior", "text_encoder"] + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + prior: WuerstchenPrior, + scheduler: DDPMWuerstchenScheduler, + latent_mean: float = 42.0, + latent_std: float = 1.0, + resolution_multiple: float = 42.67, + ) -> None: + super().__init__() + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + prior=prior, + scheduler=scheduler, + ) + self.register_to_config( + latent_mean=latent_mean, latent_std=latent_std, resolution_multiple=resolution_multiple + ) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def encode_prompt( + self, + device, + num_images_per_prompt, + do_classifier_free_guidance, + prompt=None, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + attention_mask = attention_mask[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask.to(device) + ) + prompt_embeds = text_encoder_output.last_hidden_state + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + if negative_prompt_embeds is None and do_classifier_free_guidance: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds_text_encoder_output = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device) + ) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.last_hidden_state + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + # done duplicates + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + num_inference_steps, + do_classifier_free_guidance, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if not isinstance(num_inference_steps, int): + raise TypeError( + f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}\ + In Case you want to provide explicit timesteps, please use the 'timesteps' argument." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + height: int = 1024, + width: int = 1024, + num_inference_steps: int = 60, + timesteps: list[float] = None, + guidance_scale: float = 8.0, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pt", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 60): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 8.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `decoder_guidance_scale` is defined as `w` of + equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by + setting `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are + closely linked to the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `decoder_guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.WuerstchenPriorPipelineOutput`] or `tuple` [`~pipelines.WuerstchenPriorPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated image embeddings. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + # 0. Define commonly used variables + device = self._execution_device + self._guidance_scale = guidance_scale + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 1. Check inputs. Raise error if not correct + if prompt is not None and not isinstance(prompt, list): + if isinstance(prompt, str): + prompt = [prompt] + else: + raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.") + + if self.do_classifier_free_guidance: + if negative_prompt is not None and not isinstance(negative_prompt, list): + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + else: + raise TypeError( + f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}." + ) + + self.check_inputs( + prompt, + negative_prompt, + num_inference_steps, + self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 2. Encode caption + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_encoder_hidden_states = ( + torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds + ) + + # 3. Determine latent shape of image embeddings + dtype = text_encoder_hidden_states.dtype + latent_height = ceil(height / self.config.resolution_multiple) + latent_width = ceil(width / self.config.resolution_multiple) + num_channels = self.prior.config.c_in + effnet_features_shape = (num_images_per_prompt * batch_size, num_channels, latent_height, latent_width) + + # 4. Prepare and set timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latents + latents = self.prepare_latents(effnet_features_shape, dtype, device, generator, latents, self.scheduler) + + # 6. Run denoising loop + self._num_timesteps = len(timesteps[:-1]) + for i, t in enumerate(self.progress_bar(timesteps[:-1])): + ratio = t.expand(latents.size(0)).to(dtype) + + # 7. Denoise image embeddings + predicted_image_embedding = self.prior( + torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents, + r=torch.cat([ratio] * 2) if self.do_classifier_free_guidance else ratio, + c=text_encoder_hidden_states, + ) + + # 8. Check for classifier free guidance and apply it + if self.do_classifier_free_guidance: + predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) + predicted_image_embedding = torch.lerp( + predicted_image_embedding_uncond, predicted_image_embedding_text, self.guidance_scale + ) + + # 9. Renoise latents to next timestep + latents = self.scheduler.step( + model_output=predicted_image_embedding, + timestep=ratio, + sample=latents, + generator=generator, + ).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + text_encoder_hidden_states = callback_outputs.pop( + "text_encoder_hidden_states", text_encoder_hidden_states + ) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # 10. Denormalize the latents + latents = latents * self.config.latent_mean - self.config.latent_std + + # Offload all models + self.maybe_free_model_hooks() + + if output_type == "np": + latents = latents.cpu().float().numpy() + + if not return_dict: + return (latents,) + + return WuerstchenPriorPipelineOutput(latents) diff --git a/diffusers/quantizers/bitsandbytes/__init__.py b/diffusers/quantizers/bitsandbytes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9e745bc810faad76608b00ab9c493369466d0970 --- /dev/null +++ b/diffusers/quantizers/bitsandbytes/__init__.py @@ -0,0 +1,2 @@ +from .bnb_quantizer import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer +from .utils import dequantize_and_replace, dequantize_bnb_weight, replace_with_bnb_linear diff --git a/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/diffusers/quantizers/bitsandbytes/bnb_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..16ff0d83b8c4311c72003c97c939bf42262610c6 --- /dev/null +++ b/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -0,0 +1,577 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Adapted from +https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/quantizer_bnb_4bit.py +""" + +from typing import TYPE_CHECKING, Any + +from ...utils import get_module_from_name +from ..base import DiffusersQuantizer + + +if TYPE_CHECKING: + from ...models.modeling_utils import ModelMixin + +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + is_bitsandbytes_available, + is_bitsandbytes_version, + is_torch_available, + logging, +) + + +if is_torch_available(): + import torch + +logger = logging.get_logger(__name__) + + +class BnB4BitDiffusersQuantizer(DiffusersQuantizer): + """ + 4-bit quantization from bitsandbytes.py quantization method: + before loading: converts transformer layers into Linear4bit during loading: load 16bit weight and pass to the + layer object after: quantizes individual weights in Linear4bit into 4bit at the first .cuda() call saving: + from state dict, as usual; saves weights and `quant_state` components + loading: + need to locate `quant_state` components and pass to Param4bit constructor + """ + + use_keep_in_fp32_modules = True + requires_calibration = False + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + + if self.quantization_config.llm_int8_skip_modules is not None: + self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules + + def validate_environment(self, *args, **kwargs): + if not (torch.cuda.is_available() or torch.xpu.is_available()): + raise RuntimeError("No GPU found. A GPU is needed for quantization.") + if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"): + raise ImportError( + "Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`" + ) + if not is_bitsandbytes_available() or is_bitsandbytes_version("<", "0.43.3"): + raise ImportError( + "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`" + ) + + if kwargs.get("from_flax", False): + raise ValueError( + "Converting into 4-bit weights from flax weights is currently not supported, please make" + " sure the weights are in PyTorch format." + ) + + device_map = kwargs.get("device_map", None) + if ( + device_map is not None + and isinstance(device_map, dict) + and not self.quantization_config.llm_int8_enable_fp32_cpu_offload + ): + device_map_without_no_convert = { + key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert + } + if "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values(): + raise ValueError( + "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the " + "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules " + "in 32-bit, you need to set `load_in_8bit_fp32_cpu_offload=True` and pass a custom `device_map` to " + "`from_pretrained`. Check " + "https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu " + "for more details. " + ) + + def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": + if target_dtype != torch.int8: + from accelerate.utils import CustomDtype + + logger.info("target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization") + return CustomDtype.INT4 + else: + raise ValueError(f"Wrong `target_dtype` ({target_dtype}) provided.") + + def check_if_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + state_dict: dict[str, Any], + **kwargs, + ) -> bool: + import bitsandbytes as bnb + + module, tensor_name = get_module_from_name(model, param_name) + if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit): + # Add here check for loaded components' dtypes once serialization is implemented + return True + elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias": + # bias could be loaded by regular set_module_tensor_to_device() from accelerate, + # but it would wrongly use uninitialized weight there. + return True + else: + return False + + def create_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + state_dict: dict[str, Any], + unexpected_keys: list[str] | None = None, + **kwargs, + ): + import bitsandbytes as bnb + + module, tensor_name = get_module_from_name(model, param_name) + + if tensor_name not in module._parameters: + raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") + + old_value = getattr(module, tensor_name) + + if tensor_name == "bias": + if param_value is None: + new_value = old_value.to(target_device) + else: + new_value = param_value.to(target_device) + + new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad) + module._parameters[tensor_name] = new_value + return + + if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit): + raise ValueError("this function only loads `Linear4bit components`") + if ( + old_value.device == torch.device("meta") + and target_device not in ["meta", torch.device("meta")] + and param_value is None + ): + raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.") + + # construct `new_value` for the module._parameters[tensor_name]: + if self.pre_quantized: + # 4bit loading. Collecting components for restoring quantized weight + # This can be expanded to make a universal call for any quantized weight loading + + if not self.is_serializable: + raise ValueError( + "Detected int4 weights but the version of bitsandbytes is not compatible with int4 serialization. " + "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`." + ) + + if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and ( + param_name + ".quant_state.bitsandbytes__nf4" not in state_dict + ): + raise ValueError( + f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components." + ) + + quantized_stats = {} + for k, v in state_dict.items(): + # `startswith` to counter for edge cases where `param_name` + # substring can be present in multiple places in the `state_dict` + if param_name + "." in k and k.startswith(param_name): + quantized_stats[k] = v + if unexpected_keys is not None and k in unexpected_keys: + unexpected_keys.remove(k) + + new_value = bnb.nn.Params4bit.from_prequantized( + data=param_value, + quantized_stats=quantized_stats, + requires_grad=False, + device=target_device, + ) + else: + new_value = param_value.to("cpu") + kwargs = old_value.__dict__ + new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device) + + module._parameters[tensor_name] = new_value + + def check_quantized_param_shape(self, param_name, current_param, loaded_param): + current_param_shape = current_param.shape + loaded_param_shape = loaded_param.shape + + n = current_param_shape.numel() + inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1) + if loaded_param_shape != inferred_shape: + raise ValueError( + f"Expected the flattened shape of the current param ({param_name}) to be {loaded_param_shape} but is {inferred_shape}." + ) + else: + return True + + def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]: + # need more space for buffers that are created during quantization + max_memory = {key: val * 0.90 for key, val in max_memory.items()} + return max_memory + + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + if torch_dtype is None: + # We force the `dtype` to be float16, this is a requirement from `bitsandbytes` + logger.info( + "Overriding torch_dtype=%s with `torch_dtype=torch.float16` due to " + "requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. " + "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass" + " torch_dtype=torch.float16 to remove this warning.", + torch_dtype, + ) + torch_dtype = torch.float16 + return torch_dtype + + def update_device_map(self, device_map): + if device_map is None: + if torch.xpu.is_available(): + current_device = f"xpu:{torch.xpu.current_device()}" + else: + current_device = f"cuda:{torch.cuda.current_device()}" + device_map = {"": current_device} + logger.info( + "The device_map was not initialized. " + "Setting device_map to {" + ": {current_device}}. " + "If you want to use the model for inference, please set device_map ='auto' " + ) + return device_map + + def _process_model_before_weight_loading( + self, + model: "ModelMixin", + device_map, + keep_in_fp32_modules: list[str] = [], + **kwargs, + ): + from .utils import replace_with_bnb_linear + + load_in_8bit_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload + + # We may keep some modules such as the `proj_out` in their original dtype for numerical stability reasons + self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules + + if not isinstance(self.modules_to_not_convert, list): + self.modules_to_not_convert = [self.modules_to_not_convert] + + self.modules_to_not_convert.extend(keep_in_fp32_modules) + + # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk` + if isinstance(device_map, dict) and len(device_map.keys()) > 1: + keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]] + + if len(keys_on_cpu) > 0 and not load_in_8bit_fp32_cpu_offload: + raise ValueError( + "If you want to offload some keys to `cpu` or `disk`, you need to set " + "`llm_int8_enable_fp32_cpu_offload=True`. Note that these modules will not be " + " converted to 8-bit but kept in 32-bit." + ) + self.modules_to_not_convert.extend(keys_on_cpu) + + # Purge `None`. + # Unlike `transformers`, we don't know if we should always keep certain modules in FP32 + # in case of diffusion transformer models. For language models and others alike, `lm_head` + # and tied modules are usually kept in FP32. + self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None] + + model = replace_with_bnb_linear( + model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config + ) + model.config.quantization_config = self.quantization_config + model.is_loaded_in_4bit = True + + def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs): + model.is_4bit_serializable = self.is_serializable + return model + + @property + def is_serializable(self): + # Because we're mandating `bitsandbytes` 0.43.3. + return True + + @property + def is_trainable(self) -> bool: + # Because we're mandating `bitsandbytes` 0.43.3. + return True + + def _dequantize(self, model): + from .utils import dequantize_and_replace + + is_model_on_cpu = model.device.type == "cpu" + if is_model_on_cpu: + logger.info( + "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device." + ) + if torch.xpu.is_available(): + model.to(torch.xpu.current_device()) + else: + model.to(torch.cuda.current_device()) + + model = dequantize_and_replace( + model, self.modules_to_not_convert, quantization_config=self.quantization_config + ) + if is_model_on_cpu: + model.to("cpu") + return model + + +class BnB8BitDiffusersQuantizer(DiffusersQuantizer): + """ + 8-bit quantization from bitsandbytes quantization method: + before loading: converts transformer layers into Linear8bitLt during loading: load 16bit weight and pass to the + layer object after: quantizes individual weights in Linear8bitLt into 8bit at fitst .cuda() call + saving: + from state dict, as usual; saves weights and 'SCB' component + loading: + need to locate SCB component and pass to the Linear8bitLt object + """ + + use_keep_in_fp32_modules = True + requires_calibration = False + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + + if self.quantization_config.llm_int8_skip_modules is not None: + self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules + + def validate_environment(self, *args, **kwargs): + if not (torch.cuda.is_available() or torch.xpu.is_available()): + raise RuntimeError("No GPU found. A GPU is needed for quantization.") + if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"): + raise ImportError( + "Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`" + ) + if not is_bitsandbytes_available() or is_bitsandbytes_version("<", "0.43.3"): + raise ImportError( + "Using `bitsandbytes` 8-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`" + ) + + if kwargs.get("from_flax", False): + raise ValueError( + "Converting into 8-bit weights from flax weights is currently not supported, please make" + " sure the weights are in PyTorch format." + ) + + device_map = kwargs.get("device_map", None) + if ( + device_map is not None + and isinstance(device_map, dict) + and not self.quantization_config.llm_int8_enable_fp32_cpu_offload + ): + device_map_without_no_convert = { + key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert + } + if "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values(): + raise ValueError( + "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the " + "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules " + "in 32-bit, you need to set `load_in_8bit_fp32_cpu_offload=True` and pass a custom `device_map` to " + "`from_pretrained`. Check " + "https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu " + "for more details. " + ) + + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.adjust_max_memory + def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]: + # need more space for buffers that are created during quantization + max_memory = {key: val * 0.90 for key, val in max_memory.items()} + return max_memory + + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_torch_dtype + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + if torch_dtype is None: + # We force the `dtype` to be float16, this is a requirement from `bitsandbytes` + logger.info( + "Overriding torch_dtype=%s with `torch_dtype=torch.float16` due to " + "requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. " + "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass" + " torch_dtype=torch.float16 to remove this warning.", + torch_dtype, + ) + torch_dtype = torch.float16 + return torch_dtype + + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map + def update_device_map(self, device_map): + if device_map is None: + if torch.xpu.is_available(): + current_device = f"xpu:{torch.xpu.current_device()}" + else: + current_device = f"cuda:{torch.cuda.current_device()}" + device_map = {"": current_device} + logger.info( + "The device_map was not initialized. " + "Setting device_map to {" + ": {current_device}}. " + "If you want to use the model for inference, please set device_map ='auto' " + ) + return device_map + + def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": + if target_dtype != torch.int8: + logger.info("target_dtype {target_dtype} is replaced by `torch.int8` for 8-bit BnB quantization") + return torch.int8 + + def check_if_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + state_dict: dict[str, Any], + **kwargs, + ): + import bitsandbytes as bnb + + module, tensor_name = get_module_from_name(model, param_name) + if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Int8Params): + if self.pre_quantized: + if param_name.replace("weight", "SCB") not in state_dict.keys(): + raise ValueError("Missing quantization component `SCB`") + if param_value.dtype != torch.int8: + raise ValueError( + f"Incompatible dtype `{param_value.dtype}` when loading 8-bit prequantized weight. Expected `torch.int8`." + ) + return True + return False + + def create_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + state_dict: dict[str, Any], + unexpected_keys: list[str] | None = None, + **kwargs, + ): + import bitsandbytes as bnb + + fp16_statistics_key = param_name.replace("weight", "SCB") + fp16_weights_format_key = param_name.replace("weight", "weight_format") + + fp16_statistics = state_dict.get(fp16_statistics_key, None) + fp16_weights_format = state_dict.get(fp16_weights_format_key, None) + + module, tensor_name = get_module_from_name(model, param_name) + if tensor_name not in module._parameters: + raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") + + old_value = getattr(module, tensor_name) + + if not isinstance(module._parameters[tensor_name], bnb.nn.Int8Params): + raise ValueError(f"Parameter `{tensor_name}` should only be a `bnb.nn.Int8Params` instance.") + if ( + old_value.device == torch.device("meta") + and target_device not in ["meta", torch.device("meta")] + and param_value is None + ): + raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.") + + new_value = param_value.to("cpu") + if self.pre_quantized and not self.is_serializable: + raise ValueError( + "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. " + "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`." + ) + + kwargs = old_value.__dict__ + new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(target_device) + + module._parameters[tensor_name] = new_value + if fp16_statistics is not None: + setattr(module.weight, "SCB", fp16_statistics.to(target_device)) + if unexpected_keys is not None: + unexpected_keys.remove(fp16_statistics_key) + + # We just need to pop the `weight_format` keys from the state dict to remove unneeded + # messages. The correct format is correctly retrieved during the first forward pass. + if fp16_weights_format is not None and unexpected_keys is not None: + unexpected_keys.remove(fp16_weights_format_key) + + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_after_weight_loading with 4bit->8bit + def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs): + model.is_8bit_serializable = self.is_serializable + return model + + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_before_weight_loading with 4bit->8bit + def _process_model_before_weight_loading( + self, + model: "ModelMixin", + device_map, + keep_in_fp32_modules: list[str] = [], + **kwargs, + ): + from .utils import replace_with_bnb_linear + + load_in_8bit_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload + + # We may keep some modules such as the `proj_out` in their original dtype for numerical stability reasons + self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules + + if not isinstance(self.modules_to_not_convert, list): + self.modules_to_not_convert = [self.modules_to_not_convert] + + self.modules_to_not_convert.extend(keep_in_fp32_modules) + + # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk` + if isinstance(device_map, dict) and len(device_map.keys()) > 1: + keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]] + + if len(keys_on_cpu) > 0 and not load_in_8bit_fp32_cpu_offload: + raise ValueError( + "If you want to offload some keys to `cpu` or `disk`, you need to set " + "`llm_int8_enable_fp32_cpu_offload=True`. Note that these modules will not be " + " converted to 8-bit but kept in 32-bit." + ) + self.modules_to_not_convert.extend(keys_on_cpu) + + # Purge `None`. + # Unlike `transformers`, we don't know if we should always keep certain modules in FP32 + # in case of diffusion transformer models. For language models and others alike, `lm_head` + # and tied modules are usually kept in FP32. + self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None] + + model = replace_with_bnb_linear( + model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config + ) + model.config.quantization_config = self.quantization_config + model.is_loaded_in_8bit = True + + @property + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.is_serializable + def is_serializable(self): + # Because we're mandating `bitsandbytes` 0.43.3. + return True + + @property + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.is_serializable + def is_trainable(self) -> bool: + # Because we're mandating `bitsandbytes` 0.43.3. + return True + + @property + def is_compileable(self) -> bool: + return True + + def _dequantize(self, model): + from .utils import dequantize_and_replace + + model = dequantize_and_replace( + model, self.modules_to_not_convert, quantization_config=self.quantization_config + ) + return model diff --git a/diffusers/quantizers/bitsandbytes/utils.py b/diffusers/quantizers/bitsandbytes/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..aea4e7dda57f297c2cbf615b92d444ee4b7b1e92 --- /dev/null +++ b/diffusers/quantizers/bitsandbytes/utils.py @@ -0,0 +1,318 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Adapted from +https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/integrations/bitsandbytes.py +""" + +import inspect +from inspect import signature + +from ...utils import is_accelerate_available, is_bitsandbytes_available, is_torch_available, logging +from ..quantization_config import QuantizationMethod + + +if is_torch_available(): + import torch + import torch.nn as nn + +if is_bitsandbytes_available(): + import bitsandbytes as bnb + +if is_accelerate_available(): + import accelerate + from accelerate import init_empty_weights + from accelerate.hooks import add_hook_to_module, remove_hook_from_module + +logger = logging.get_logger(__name__) + + +def _replace_with_bnb_linear( + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + has_been_replaced=False, +): + """ + Private method that wraps the recursion for module replacement. + + Returns the converted model and a boolean that indicates if the conversion has been successful or not. + """ + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + + if isinstance(module, nn.Linear) and name not in modules_to_not_convert: + # Check if the current key is not in the `modules_to_not_convert` + current_key_name_str = ".".join(current_key_name) + if not any( + (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert + ): + with init_empty_weights(): + in_features = module.in_features + out_features = module.out_features + + if quantization_config.quantization_method() == "llm_int8": + model._modules[name] = bnb.nn.Linear8bitLt( + in_features, + out_features, + module.bias is not None, + has_fp16_weights=quantization_config.llm_int8_has_fp16_weight, + threshold=quantization_config.llm_int8_threshold, + ) + has_been_replaced = True + else: + if ( + quantization_config.llm_int8_skip_modules is not None + and name in quantization_config.llm_int8_skip_modules + ): + pass + else: + extra_kwargs = ( + {"quant_storage": quantization_config.bnb_4bit_quant_storage} + if "quant_storage" in list(signature(bnb.nn.Linear4bit).parameters) + else {} + ) + model._modules[name] = bnb.nn.Linear4bit( + in_features, + out_features, + module.bias is not None, + quantization_config.bnb_4bit_compute_dtype, + compress_statistics=quantization_config.bnb_4bit_use_double_quant, + quant_type=quantization_config.bnb_4bit_quant_type, + **extra_kwargs, + ) + has_been_replaced = True + # Store the module class in case we need to transpose the weight later + model._modules[name].source_cls = type(module) + # Force requires grad to False to avoid unexpected errors + model._modules[name].requires_grad_(False) + if len(list(module.children())) > 0: + _, has_been_replaced = _replace_with_bnb_linear( + module, + modules_to_not_convert, + current_key_name, + quantization_config, + has_been_replaced=has_been_replaced, + ) + # Remove the last key for recursion + current_key_name.pop(-1) + return model, has_been_replaced + + +def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None): + """ + Helper function to replace the `nn.Linear` layers within `model` with either `bnb.nn.Linear8bit` or + `bnb.nn.Linear4bit` using the `bitsandbytes` library. + + References: + * `bnb.nn.Linear8bit`: [LLM.int8(): 8-bit Matrix Multiplication for Transformers at + Scale](https://huggingface.co/papers/2208.07339) + * `bnb.nn.Linear4bit`: [QLoRA: Efficient Finetuning of Quantized + LLMs](https://huggingface.co/papers/2305.14314) + + Parameters: + model (`torch.nn.Module`): + Input model or `torch.nn.Module` as the function is run recursively. + modules_to_not_convert (`list[`str`]`, *optional*, defaults to `[]`): + Names of the modules to not convert in `Linear8bitLt`. In practice we keep the `modules_to_not_convert` in + full precision for numerical stability reasons. + current_key_name (`list[`str`]`, *optional*): + An array to track the current key of the recursion. This is used to check whether the current key (part of + it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or + `disk`). + quantization_config ('transformers.utils.quantization_config.BitsAndBytesConfig'): + To configure and manage settings related to quantization, a technique used to compress neural network + models by reducing the precision of the weights and activations, thus making models more efficient in terms + of both storage and computation. + """ + model, _ = _replace_with_bnb_linear(model, modules_to_not_convert, current_key_name, quantization_config) + + has_been_replaced = any( + isinstance(replaced_module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)) + for _, replaced_module in model.named_modules() + ) + if not has_been_replaced: + logger.warning( + "You are loading your model in 8bit or 4bit but no linear modules were found in your model." + " Please double check your model architecture, or submit an issue on github if you think this is" + " a bug." + ) + + return model + + +# Adapted from PEFT: https://github.com/huggingface/peft/blob/6d458b300fc2ed82e19f796b53af4c97d03ea604/src/peft/utils/integrations.py#L81 +def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None, dtype: "torch.dtype" = None): + """ + Helper function to dequantize 4bit or 8bit bnb weights. + + If the weight is not a bnb quantized weight, it will be returned as is. + """ + if not isinstance(weight, torch.nn.Parameter): + raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead") + + cls_name = weight.__class__.__name__ + if cls_name not in ("Params4bit", "Int8Params"): + return weight + + if cls_name == "Params4bit": + output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) + msg = f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`" + if dtype: + msg = f"The model is going to be first dequantized in {output_tensor.dtype} and type-casted to {dtype}" + output_tensor = output_tensor.to(dtype) + logger.warning_once(msg) + return output_tensor + + if state.SCB is None: + state.SCB = weight.SCB + + if hasattr(bnb.functional, "int8_vectorwise_dequant"): + # Use bitsandbytes API if available (requires v0.45.0+) + dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB) + else: + # Multiply by (scale/127) to dequantize. + dequantized = weight.data * state.SCB.view(-1, 1) * 7.874015718698502e-3 + + if dtype: + dequantized = dequantized.to(dtype) + return dequantized + + +def _create_accelerate_new_hook(old_hook): + r""" + Creates a new hook based on the old hook. Use it only if you know what you are doing ! This method is a copy of: + https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245 with + some changes + """ + old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__) + old_hook_attr = old_hook.__dict__ + filtered_old_hook_attr = {} + old_hook_init_signature = inspect.signature(old_hook_cls.__init__) + for k in old_hook_attr.keys(): + if k in old_hook_init_signature.parameters: + filtered_old_hook_attr[k] = old_hook_attr[k] + new_hook = old_hook_cls(**filtered_old_hook_attr) + return new_hook + + +def _dequantize_and_replace( + model, + dtype, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + has_been_replaced=False, +): + """ + Converts a quantized model into its dequantized original version. The newly converted model will have some + performance drop compared to the original model before quantization - use it only for specific usecases such as + QLoRA adapters merging. + + Returns the converted model and a boolean that indicates if the conversion has been successful or not. + """ + quant_method = quantization_config.quantization_method() + + target_cls = bnb.nn.Linear8bitLt if quant_method == "llm_int8" else bnb.nn.Linear4bit + + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + + if isinstance(module, target_cls) and name not in modules_to_not_convert: + # Check if the current key is not in the `modules_to_not_convert` + current_key_name_str = ".".join(current_key_name) + + if not any( + (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert + ): + bias = getattr(module, "bias", None) + + device = module.weight.device + with init_empty_weights(): + new_module = torch.nn.Linear(module.in_features, module.out_features, bias=bias is not None) + + if quant_method == "llm_int8": + state = module.state + else: + state = None + + new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state, dtype)) + + if bias is not None: + new_module.bias = bias + + # Create a new hook and attach it in case we use accelerate + if hasattr(module, "_hf_hook"): + old_hook = module._hf_hook + new_hook = _create_accelerate_new_hook(old_hook) + + remove_hook_from_module(module) + add_hook_to_module(new_module, new_hook) + + new_module.to(device) + model._modules[name] = new_module + has_been_replaced = True + if len(list(module.children())) > 0: + _, has_been_replaced = _dequantize_and_replace( + module, + dtype=dtype, + modules_to_not_convert=modules_to_not_convert, + current_key_name=current_key_name, + quantization_config=quantization_config, + has_been_replaced=has_been_replaced, + ) + # Remove the last key for recursion + current_key_name.pop(-1) + return model, has_been_replaced + + +def dequantize_and_replace( + model, + modules_to_not_convert=None, + quantization_config=None, +): + model, _ = _dequantize_and_replace( + model, + dtype=model.dtype, + modules_to_not_convert=modules_to_not_convert, + quantization_config=quantization_config, + ) + has_been_replaced = any( + isinstance(replaced_module, torch.nn.Linear) for _, replaced_module in model.named_modules() + ) + if not has_been_replaced: + logger.warning( + "Some linear modules were not dequantized. This could lead to unexpected behaviour. Please check your model." + ) + + return model + + +def _check_bnb_status(module) -> bool | bool: + is_loaded_in_4bit_bnb = ( + hasattr(module, "is_loaded_in_4bit") + and module.is_loaded_in_4bit + and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES + ) + is_loaded_in_8bit_bnb = ( + hasattr(module, "is_loaded_in_8bit") + and module.is_loaded_in_8bit + and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES + ) + return is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb diff --git a/diffusers/quantizers/gguf/__init__.py b/diffusers/quantizers/gguf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b3d9082ac803c8da97ba7bb37021445307b96f7c --- /dev/null +++ b/diffusers/quantizers/gguf/__init__.py @@ -0,0 +1 @@ +from .gguf_quantizer import GGUFQuantizer diff --git a/diffusers/quantizers/gguf/gguf_quantizer.py b/diffusers/quantizers/gguf/gguf_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..15f39dd9605eddb44d0b169cbc492e84afdc8bca --- /dev/null +++ b/diffusers/quantizers/gguf/gguf_quantizer.py @@ -0,0 +1,171 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from ..base import DiffusersQuantizer + + +if TYPE_CHECKING: + from ...models.modeling_utils import ModelMixin + + +from ...utils import ( + get_module_from_name, + is_accelerate_available, + is_accelerate_version, + is_gguf_available, + is_gguf_version, + is_torch_available, + logging, +) + + +if is_torch_available() and is_gguf_available(): + import torch + + from .utils import ( + GGML_QUANT_SIZES, + GGUFParameter, + _dequantize_gguf_and_restore_linear, + _quant_shape_from_byte_shape, + _replace_with_gguf_linear, + ) + + +logger = logging.get_logger(__name__) + + +class GGUFQuantizer(DiffusersQuantizer): + use_keep_in_fp32_modules = True + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + + self.compute_dtype = quantization_config.compute_dtype + self.pre_quantized = quantization_config.pre_quantized + self.modules_to_not_convert = quantization_config.modules_to_not_convert or [] + + if not isinstance(self.modules_to_not_convert, list): + self.modules_to_not_convert = [self.modules_to_not_convert] + + def validate_environment(self, *args, **kwargs): + if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"): + raise ImportError( + "Loading GGUF Parameters requires `accelerate` installed in your environment: `pip install 'accelerate>=0.26.0'`" + ) + if not is_gguf_available() or is_gguf_version("<", "0.10.0"): + raise ImportError( + "To load GGUF format files you must have `gguf` installed in your environment: `pip install gguf>=0.10.0`" + ) + + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.adjust_max_memory + def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]: + # need more space for buffers that are created during quantization + max_memory = {key: val * 0.90 for key, val in max_memory.items()} + return max_memory + + def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": + if target_dtype != torch.uint8: + logger.info(f"target_dtype {target_dtype} is replaced by `torch.uint8` for GGUF quantization") + return torch.uint8 + + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + if torch_dtype is None: + torch_dtype = self.compute_dtype + return torch_dtype + + def check_quantized_param_shape(self, param_name, current_param, loaded_param): + loaded_param_shape = loaded_param.shape + current_param_shape = current_param.shape + quant_type = loaded_param.quant_type + + block_size, type_size = GGML_QUANT_SIZES[quant_type] + + inferred_shape = _quant_shape_from_byte_shape(loaded_param_shape, type_size, block_size) + if inferred_shape != current_param_shape: + raise ValueError( + f"{param_name} has an expected quantized shape of: {inferred_shape}, but received shape: {loaded_param_shape}" + ) + + return True + + def check_if_quantized_param( + self, + model: "ModelMixin", + param_value: "GGUFParameter" | "torch.Tensor", + param_name: str, + state_dict: dict[str, Any], + **kwargs, + ) -> bool: + if isinstance(param_value, GGUFParameter): + return True + + return False + + def create_quantized_param( + self, + model: "ModelMixin", + param_value: "GGUFParameter" | "torch.Tensor", + param_name: str, + target_device: "torch.device", + state_dict: dict[str, Any] | None = None, + unexpected_keys: list[str] | None = None, + **kwargs, + ): + module, tensor_name = get_module_from_name(model, param_name) + if tensor_name not in module._parameters and tensor_name not in module._buffers: + raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") + + if tensor_name in module._parameters: + module._parameters[tensor_name] = param_value.to(target_device) + if tensor_name in module._buffers: + module._buffers[tensor_name] = param_value.to(target_device) + + def _process_model_before_weight_loading( + self, + model: "ModelMixin", + device_map, + keep_in_fp32_modules: list[str] = [], + **kwargs, + ): + state_dict = kwargs.get("state_dict", None) + + self.modules_to_not_convert.extend(keep_in_fp32_modules) + self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None] + + _replace_with_gguf_linear( + model, self.compute_dtype, state_dict, modules_to_not_convert=self.modules_to_not_convert + ) + + def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs): + return model + + @property + def is_serializable(self): + return False + + @property + def is_trainable(self) -> bool: + return False + + @property + def is_compileable(self) -> bool: + return True + + def _dequantize(self, model): + is_model_on_cpu = model.device.type == "cpu" + if is_model_on_cpu: + logger.info( + "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to accelerator. After dequantization, will move the model back to CPU again to preserve the previous device." + ) + device = ( + torch.accelerator.current_accelerator() + if hasattr(torch, "accelerator") + else torch.cuda.current_device() + ) + model.to(device) + + model = _dequantize_gguf_and_restore_linear(model, self.modules_to_not_convert) + if is_model_on_cpu: + model.to("cpu") + return model diff --git a/diffusers/quantizers/gguf/utils.py b/diffusers/quantizers/gguf/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e0ad0e1cce424c840e64b85695735bcc93c9cd86 --- /dev/null +++ b/diffusers/quantizers/gguf/utils.py @@ -0,0 +1,610 @@ +# Copyright 2025 The HuggingFace Team and City96. All rights reserved. +# # +# # Licensed under the Apache License, Version 2.0 (the "License"); +# # you may not use this file except in compliance with the License. +# # You may obtain a copy of the License at +# # +# # http://www.apache.org/licenses/LICENSE-2.0 +# # +# # Unless required by applicable law or agreed to in writing, software +# # distributed under the License is distributed on an "AS IS" BASIS, +# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# # See the License for the specific language governing permissions and +# # limitations under the License. + +import inspect +import os +from contextlib import nullcontext + +import gguf +import torch +import torch.nn as nn + +from ...utils import is_accelerate_available, is_kernels_available + + +if is_accelerate_available(): + import accelerate + from accelerate import init_empty_weights + from accelerate.hooks import add_hook_to_module, remove_hook_from_module + + +can_use_cuda_kernels = ( + os.getenv("DIFFUSERS_GGUF_CUDA_KERNELS", "false").lower() in ["1", "true", "yes"] + and torch.cuda.is_available() + and torch.cuda.get_device_capability()[0] >= 7 +) +if can_use_cuda_kernels and is_kernels_available(): + from kernels import get_kernel + + ops = get_kernel("Isotr0py/ggml") +else: + ops = None + +UNQUANTIZED_TYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16} +STANDARD_QUANT_TYPES = { + gguf.GGMLQuantizationType.Q4_0, + gguf.GGMLQuantizationType.Q4_1, + gguf.GGMLQuantizationType.Q5_0, + gguf.GGMLQuantizationType.Q5_1, + gguf.GGMLQuantizationType.Q8_0, + gguf.GGMLQuantizationType.Q8_1, +} +KQUANT_TYPES = { + gguf.GGMLQuantizationType.Q2_K, + gguf.GGMLQuantizationType.Q3_K, + gguf.GGMLQuantizationType.Q4_K, + gguf.GGMLQuantizationType.Q5_K, + gguf.GGMLQuantizationType.Q6_K, +} +IMATRIX_QUANT_TYPES = { + gguf.GGMLQuantizationType.IQ1_M, + gguf.GGMLQuantizationType.IQ1_S, + gguf.GGMLQuantizationType.IQ2_XXS, + gguf.GGMLQuantizationType.IQ2_XS, + gguf.GGMLQuantizationType.IQ2_S, + gguf.GGMLQuantizationType.IQ3_XXS, + gguf.GGMLQuantizationType.IQ3_S, + gguf.GGMLQuantizationType.IQ4_XS, + gguf.GGMLQuantizationType.IQ4_NL, +} +# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization. +# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add +# MMQ kernel for I-Matrix quantization. +DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES +MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES +MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES + + +def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor: + # there is no need to call any kernel for fp16/bf16 + if qweight_type in UNQUANTIZED_TYPES: + weight = dequantize_gguf_tensor(qweight) + return x @ weight.T + + # TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for + # contiguous batching and inefficient with diffusers' batching, + # so we disabled it now. + + # elif qweight_type in MMVQ_QUANT_TYPES: + # y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0]) + # elif qweight_type in MMQ_QUANT_TYPES: + # y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0]) + + # If there is no available MMQ kernel, fallback to dequantize + if qweight_type in DEQUANT_TYPES: + block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] + shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size) + weight = ops.ggml_dequantize(qweight, qweight_type, *shape) + y = x @ weight.to(x.dtype).T + else: + # Raise an error if the quantization type is not supported. + # Might be useful if llama.cpp adds a new quantization type. + # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type. + qweight_type = gguf.GGMLQuantizationType(qweight_type) + raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}") + return y.as_tensor() + + +# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook +def _create_accelerate_new_hook(old_hook): + r""" + Creates a new hook based on the old hook. Use it only if you know what you are doing ! This method is a copy of: + https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245 with + some changes + """ + old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__) + old_hook_attr = old_hook.__dict__ + filtered_old_hook_attr = {} + old_hook_init_signature = inspect.signature(old_hook_cls.__init__) + for k in old_hook_attr.keys(): + if k in old_hook_init_signature.parameters: + filtered_old_hook_attr[k] = old_hook_attr[k] + new_hook = old_hook_cls(**filtered_old_hook_attr) + return new_hook + + +def _replace_with_gguf_linear(model, compute_dtype, state_dict, prefix="", modules_to_not_convert=[]): + def _should_convert_to_gguf(state_dict, prefix): + weight_key = prefix + "weight" + return weight_key in state_dict and isinstance(state_dict[weight_key], GGUFParameter) + + has_children = list(model.children()) + if not has_children: + return + + for name, module in model.named_children(): + module_prefix = prefix + name + "." + _replace_with_gguf_linear(module, compute_dtype, state_dict, module_prefix, modules_to_not_convert) + + if ( + isinstance(module, nn.Linear) + and _should_convert_to_gguf(state_dict, module_prefix) + and name not in modules_to_not_convert + ): + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + model._modules[name] = GGUFLinear( + module.in_features, + module.out_features, + module.bias is not None, + compute_dtype=compute_dtype, + ) + model._modules[name].source_cls = type(module) + # Force requires_grad to False to avoid unexpected errors + model._modules[name].requires_grad_(False) + + return model + + +def _dequantize_gguf_and_restore_linear(model, modules_to_not_convert=[]): + for name, module in model.named_children(): + if isinstance(module, GGUFLinear) and name not in modules_to_not_convert: + device = module.weight.device + bias = getattr(module, "bias", None) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + new_module = nn.Linear( + module.in_features, + module.out_features, + module.bias is not None, + device=device, + ) + new_module.weight = nn.Parameter(dequantize_gguf_tensor(module.weight)) + if bias is not None: + new_module.bias = bias + + # Create a new hook and attach it in case we use accelerate + if hasattr(module, "_hf_hook"): + old_hook = module._hf_hook + new_hook = _create_accelerate_new_hook(old_hook) + + remove_hook_from_module(module) + add_hook_to_module(new_module, new_hook) + + new_module.to(device) + model._modules[name] = new_module + + has_children = list(module.children()) + if has_children: + _dequantize_gguf_and_restore_linear(module, modules_to_not_convert) + + return model + + +# dequantize operations based on torch ports of GGUF dequantize_functions +# from City96 +# more info: https://github.com/city96/ComfyUI-GGUF/blob/main/dequant.py + + +QK_K = 256 +K_SCALE_SIZE = 12 + + +def to_uint32(x): + x = x.view(torch.uint8).to(torch.int32) + return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1) + + +def split_block_dims(blocks, *args): + n_max = blocks.shape[1] + dims = list(args) + [n_max - sum(args)] + return torch.split(blocks, dims, dim=1) + + +def get_scale_min(scales): + n_blocks = scales.shape[0] + scales = scales.view(torch.uint8) + scales = scales.reshape((n_blocks, 3, 4)) + + d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2) + + sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1) + min = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1) + + return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8))) + + +def dequantize_blocks_Q8_0(blocks, block_size, type_size, dtype=None): + d, x = split_block_dims(blocks, 2) + d = d.view(torch.float16).to(dtype) + x = x.view(torch.int8) + return d * x + + +def dequantize_blocks_Q5_1(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, m, qh, qs = split_block_dims(blocks, 2, 2, 4) + d = d.view(torch.float16).to(dtype) + m = m.view(torch.float16).to(dtype) + qh = to_uint32(qh) + + qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32) + ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor( + [0, 4], device=d.device, dtype=torch.uint8 + ).reshape(1, 1, 2, 1) + qh = (qh & 1).to(torch.uint8) + ql = (ql & 0x0F).reshape((n_blocks, -1)) + + qs = ql | (qh << 4) + return (d * qs) + m + + +def dequantize_blocks_Q5_0(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, qh, qs = split_block_dims(blocks, 2, 4) + d = d.view(torch.float16).to(dtype) + qh = to_uint32(qh) + + qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32) + ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor( + [0, 4], device=d.device, dtype=torch.uint8 + ).reshape(1, 1, 2, 1) + + qh = (qh & 1).to(torch.uint8) + ql = (ql & 0x0F).reshape(n_blocks, -1) + + qs = (ql | (qh << 4)).to(torch.int8) - 16 + return d * qs + + +def dequantize_blocks_Q4_1(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, m, qs = split_block_dims(blocks, 2, 2) + d = d.view(torch.float16).to(dtype) + m = m.view(torch.float16).to(dtype) + + qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor( + [0, 4], device=d.device, dtype=torch.uint8 + ).reshape(1, 1, 2, 1) + qs = (qs & 0x0F).reshape(n_blocks, -1) + + return (d * qs) + m + + +def dequantize_blocks_Q4_0(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, qs = split_block_dims(blocks, 2) + d = d.view(torch.float16).to(dtype) + + qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor( + [0, 4], device=d.device, dtype=torch.uint8 + ).reshape((1, 1, 2, 1)) + qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8 + return d * qs + + +def dequantize_blocks_Q6_K(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + ( + ql, + qh, + scales, + d, + ) = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16) + + scales = scales.view(torch.int8).to(dtype) + d = d.view(torch.float16).to(dtype) + d = (d * scales).reshape((n_blocks, QK_K // 16, 1)) + + ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape( + (1, 1, 2, 1) + ) + ql = (ql & 0x0F).reshape((n_blocks, -1, 32)) + qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape( + (1, 1, 4, 1) + ) + qh = (qh & 0x03).reshape((n_blocks, -1, 32)) + q = (ql | (qh << 4)).to(torch.int8) - 32 + q = q.reshape((n_blocks, QK_K // 16, -1)) + + return (d * q).reshape((n_blocks, QK_K)) + + +def dequantize_blocks_Q5_K(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8) + + d = d.view(torch.float16).to(dtype) + dmin = dmin.view(torch.float16).to(dtype) + + sc, m = get_scale_min(scales) + + d = (d * sc).reshape((n_blocks, -1, 1)) + dm = (dmin * m).reshape((n_blocks, -1, 1)) + + ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape( + (1, 1, 2, 1) + ) + qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.arange(0, 8, device=d.device, dtype=torch.uint8).reshape( + (1, 1, 8, 1) + ) + ql = (ql & 0x0F).reshape((n_blocks, -1, 32)) + qh = (qh & 0x01).reshape((n_blocks, -1, 32)) + q = ql | (qh << 4) + + return (d * q - dm).reshape((n_blocks, QK_K)) + + +def dequantize_blocks_Q4_K(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE) + d = d.view(torch.float16).to(dtype) + dmin = dmin.view(torch.float16).to(dtype) + + sc, m = get_scale_min(scales) + + d = (d * sc).reshape((n_blocks, -1, 1)) + dm = (dmin * m).reshape((n_blocks, -1, 1)) + + qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape( + (1, 1, 2, 1) + ) + qs = (qs & 0x0F).reshape((n_blocks, -1, 32)) + + return (d * qs - dm).reshape((n_blocks, QK_K)) + + +def dequantize_blocks_Q3_K(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + hmask, qs, scales, d = split_block_dims(blocks, QK_K // 8, QK_K // 4, 12) + d = d.view(torch.float16).to(dtype) + + lscales, hscales = scales[:, :8], scales[:, 8:] + lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape( + (1, 2, 1) + ) + lscales = lscales.reshape((n_blocks, 16)) + hscales = hscales.reshape((n_blocks, 1, 4)) >> torch.tensor( + [0, 2, 4, 6], device=d.device, dtype=torch.uint8 + ).reshape((1, 4, 1)) + hscales = hscales.reshape((n_blocks, 16)) + scales = (lscales & 0x0F) | ((hscales & 0x03) << 4) + scales = scales.to(torch.int8) - 32 + + dl = (d * scales).reshape((n_blocks, 16, 1)) + + ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape( + (1, 1, 4, 1) + ) + qh = hmask.reshape(n_blocks, -1, 1, 32) >> torch.arange(0, 8, device=d.device, dtype=torch.uint8).reshape( + (1, 1, 8, 1) + ) + ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3 + qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1 + q = ql.to(torch.int8) - (qh << 2).to(torch.int8) + + return (dl * q).reshape((n_blocks, QK_K)) + + +def dequantize_blocks_Q2_K(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + scales, qs, d, dmin = split_block_dims(blocks, QK_K // 16, QK_K // 4, 2) + d = d.view(torch.float16).to(dtype) + dmin = dmin.view(torch.float16).to(dtype) + + # (n_blocks, 16, 1) + dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1)) + ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1)) + + shift = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1)) + + qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3 + qs = qs.reshape((n_blocks, QK_K // 16, 16)) + qs = dl * qs - ml + + return qs.reshape((n_blocks, -1)) + + +def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None): + return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32) + + +# this part from calcuis (gguf.org) +# more info: https://github.com/calcuis/gguf-connector/blob/main/src/gguf_connector/quant2c.py + + +def dequantize_blocks_IQ4_NL(blocks, block_size, type_size, dtype=None): + kvalues = torch.tensor( + [-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], + dtype=torch.float32, + device=blocks.device, + ) + n_blocks = blocks.shape[0] + d, qs = split_block_dims(blocks, 2) + d = d.view(torch.float16).to(dtype) + qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor( + [0, 4], device=blocks.device, dtype=torch.uint8 + ).reshape((1, 1, 2, 1)) + qs = (qs & 15).reshape((n_blocks, -1)).to(torch.int64) + kvalues = kvalues.view(1, 1, 16) + qs = qs.unsqueeze(-1) + qs = torch.gather(kvalues.expand(qs.shape[0], qs.shape[1], 16), 2, qs) + qs = qs.squeeze(-1).to(dtype) + return d * qs + + +def dequantize_blocks_IQ4_XS(blocks, block_size, type_size, dtype=None): + kvalues = torch.tensor( + [-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], + dtype=torch.float32, + device=blocks.device, + ) + n_blocks = blocks.shape[0] + d, scales_h, scales_l, qs = split_block_dims(blocks, 2, 2, QK_K // 64) + d = d.view(torch.float16).to(dtype) + scales_h = scales_h.view(torch.int16) + scales_l = scales_l.reshape((n_blocks, -1, 1)) >> torch.tensor( + [0, 4], device=blocks.device, dtype=torch.uint8 + ).reshape((1, 1, 2)) + scales_h = scales_h.reshape((n_blocks, 1, -1)) >> torch.tensor( + [2 * i for i in range(QK_K // 32)], device=blocks.device, dtype=torch.uint8 + ).reshape((1, -1, 1)) + scales_l = scales_l.reshape((n_blocks, -1)) & 0x0F + scales_h = scales_h.reshape((n_blocks, -1)) & 0x03 + scales = (scales_l | (scales_h << 4)) - 32 + dl = (d * scales.to(dtype)).reshape((n_blocks, -1, 1)) + shifts_q = torch.tensor([0, 4], device=blocks.device, dtype=torch.uint8).reshape(1, 1, 2, 1) + qs = qs.reshape((n_blocks, -1, 1, 16)) >> shifts_q + qs = (qs & 15).reshape((n_blocks, -1, 32)).to(torch.int64) + kvalues = kvalues.view(1, 1, 1, 16) + qs = qs.unsqueeze(-1) + qs = torch.gather(kvalues.expand(qs.shape[0], qs.shape[1], qs.shape[2], 16), 3, qs) + qs = qs.squeeze(-1).to(dtype) + return (dl * qs).reshape(n_blocks, -1) + + +GGML_QUANT_SIZES = gguf.GGML_QUANT_SIZES +dequantize_functions = { + gguf.GGMLQuantizationType.IQ4_NL: dequantize_blocks_IQ4_NL, + gguf.GGMLQuantizationType.IQ4_XS: dequantize_blocks_IQ4_XS, + gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16, + gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0, + gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1, + gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0, + gguf.GGMLQuantizationType.Q4_1: dequantize_blocks_Q4_1, + gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0, + gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K, + gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K, + gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K, + gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K, + gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K, +} +SUPPORTED_GGUF_QUANT_TYPES = list(dequantize_functions.keys()) + + +def _quant_shape_from_byte_shape(shape, type_size, block_size): + return (*shape[:-1], shape[-1] // type_size * block_size) + + +def dequantize_gguf_tensor(tensor): + if not hasattr(tensor, "quant_type"): + return tensor + + quant_type = tensor.quant_type + dequant_fn = dequantize_functions[quant_type] + + block_size, type_size = GGML_QUANT_SIZES[quant_type] + + # Conver to plain tensor to avoid unnecessary __torch_function__ overhead. + tensor = tensor.as_tensor() + + tensor = tensor.view(torch.uint8) + shape = _quant_shape_from_byte_shape(tensor.shape, type_size, block_size) + + n_blocks = tensor.numel() // type_size + blocks = tensor.reshape((n_blocks, type_size)) + + dequant = dequant_fn(blocks, block_size, type_size) + dequant = dequant.reshape(shape) + + return dequant + + +class GGUFParameter(torch.nn.Parameter): + def __new__(cls, data, requires_grad=False, quant_type=None): + data = data if data is not None else torch.empty(0) + self = torch.Tensor._make_subclass(cls, data, requires_grad) + self.quant_type = quant_type + block_size, type_size = GGML_QUANT_SIZES[quant_type] + self.quant_shape = _quant_shape_from_byte_shape(self.shape, type_size, block_size) + + return self + + def as_tensor(self): + return torch.Tensor._make_subclass(torch.Tensor, self, self.requires_grad) + + @staticmethod + def _extract_quant_type(args): + # When converting from original format checkpoints we often use splits, cats etc on tensors + # this method ensures that the returned tensor type from those operations remains GGUFParameter + # so that we preserve quant_type information + for arg in args: + if isinstance(arg, list) and isinstance(arg[0], GGUFParameter): + return arg[0].quant_type + if isinstance(arg, GGUFParameter): + return arg.quant_type + return None + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + result = super().__torch_function__(func, types, args, kwargs) + + if isinstance(result, torch.Tensor): + quant_type = cls._extract_quant_type(args) + return cls(result, quant_type=quant_type) + # Handle tuples and lists + elif type(result) in (list, tuple): + # Preserve the original type (tuple or list) + quant_type = cls._extract_quant_type(args) + wrapped = [cls(x, quant_type=quant_type) if isinstance(x, torch.Tensor) else x for x in result] + return type(result)(wrapped) + else: + return result + + +class GGUFLinear(nn.Linear): + def __init__( + self, + in_features, + out_features, + bias=False, + compute_dtype=None, + device=None, + ) -> None: + super().__init__(in_features, out_features, bias, device) + self.compute_dtype = compute_dtype + self.device = device + + def forward(self, inputs: torch.Tensor): + if ops is not None and self.weight.is_cuda and inputs.is_cuda: + return self.forward_cuda(inputs) + return self.forward_native(inputs) + + def forward_native(self, inputs: torch.Tensor): + weight = dequantize_gguf_tensor(self.weight) + weight = weight.to(self.compute_dtype) + bias = self.bias.to(self.compute_dtype) if self.bias is not None else None + + output = torch.nn.functional.linear(inputs, weight, bias) + return output + + def forward_cuda(self, inputs: torch.Tensor): + quant_type = self.weight.quant_type + output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type) + if self.bias is not None: + output += self.bias.to(self.compute_dtype) + return output diff --git a/diffusers/quantizers/modelopt/__init__.py b/diffusers/quantizers/modelopt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae0951cb30d140011f6b3595036f3ca225b423f4 --- /dev/null +++ b/diffusers/quantizers/modelopt/__init__.py @@ -0,0 +1 @@ +from .modelopt_quantizer import NVIDIAModelOptQuantizer diff --git a/diffusers/quantizers/modelopt/modelopt_quantizer.py b/diffusers/quantizers/modelopt/modelopt_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5c9e198c82b2a1fda338cda60cf80a9adbb71c1e --- /dev/null +++ b/diffusers/quantizers/modelopt/modelopt_quantizer.py @@ -0,0 +1,190 @@ +from typing import TYPE_CHECKING, Any + +from ...utils import ( + get_module_from_name, + is_accelerate_available, + is_nvidia_modelopt_available, + is_torch_available, + logging, +) +from ..base import DiffusersQuantizer + + +if TYPE_CHECKING: + from ...models.modeling_utils import ModelMixin + + +if is_torch_available(): + import torch + import torch.nn as nn + +if is_accelerate_available(): + from accelerate.utils import set_module_tensor_to_device + + +logger = logging.get_logger(__name__) + + +class NVIDIAModelOptQuantizer(DiffusersQuantizer): + r""" + Diffusers Quantizer for Nvidia-Model Optimizer + """ + + use_keep_in_fp32_modules = True + requires_calibration = False + required_packages = ["nvidia_modelopt"] + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + + def validate_environment(self, *args, **kwargs): + if not is_nvidia_modelopt_available(): + raise ImportError( + "Loading an nvidia-modelopt quantized model requires nvidia-modelopt library (`pip install nvidia-modelopt`)" + ) + + self.offload = False + + device_map = kwargs.get("device_map", None) + if isinstance(device_map, dict): + if "cpu" in device_map.values() or "disk" in device_map.values(): + if self.pre_quantized: + raise ValueError( + "You are attempting to perform cpu/disk offload with a pre-quantized modelopt model " + "This is not supported yet. Please remove the CPU or disk device from the `device_map` argument." + ) + else: + self.offload = True + + def check_if_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + state_dict: dict[str, Any], + **kwargs, + ): + # ModelOpt imports diffusers internally. This is here to prevent circular imports + from modelopt.torch.quantization.utils import is_quantized + + module, tensor_name = get_module_from_name(model, param_name) + if self.pre_quantized: + return True + elif is_quantized(module) and "weight" in tensor_name: + return True + return False + + def create_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + *args, + **kwargs, + ): + """ + Create the quantized parameter by calling .calibrate() after setting it to the module. + """ + # ModelOpt imports diffusers internally. This is here to prevent circular imports + import modelopt.torch.quantization as mtq + + dtype = kwargs.get("dtype", torch.float32) + module, tensor_name = get_module_from_name(model, param_name) + if self.pre_quantized: + module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device)) + else: + set_module_tensor_to_device(model, param_name, target_device, param_value, dtype) + mtq.calibrate( + module, self.quantization_config.modelopt_config["algorithm"], self.quantization_config.forward_loop + ) + mtq.compress(module) + module.weight.requires_grad = False + + def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]: + max_memory = {key: val * 0.90 for key, val in max_memory.items()} + return max_memory + + def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": + if self.quantization_config.quant_type == "FP8": + target_dtype = torch.float8_e4m3fn + return target_dtype + + def update_torch_dtype(self, torch_dtype: "torch.dtype" = None) -> "torch.dtype": + if torch_dtype is None: + logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.") + torch_dtype = torch.float32 + return torch_dtype + + def get_conv_param_names(self, model: "ModelMixin") -> list[str]: + """ + Get parameter names for all convolutional layers in a HuggingFace ModelMixin. Includes Conv1d/2d/3d and + ConvTranspose1d/2d/3d. + """ + conv_types = ( + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + ) + + conv_param_names = [] + for name, module in model.named_modules(): + if isinstance(module, conv_types): + for param_name, _ in module.named_parameters(recurse=False): + conv_param_names.append(f"{name}.{param_name}") + + return conv_param_names + + def _process_model_before_weight_loading( + self, + model: "ModelMixin", + device_map, + keep_in_fp32_modules: list[str] = [], + **kwargs, + ): + # ModelOpt imports diffusers internally. This is here to prevent circular imports + import modelopt.torch.opt as mto + + if self.pre_quantized: + return + + modules_to_not_convert = self.quantization_config.modules_to_not_convert + + if modules_to_not_convert is None: + modules_to_not_convert = [] + if isinstance(modules_to_not_convert, str): + modules_to_not_convert = [modules_to_not_convert] + modules_to_not_convert.extend(keep_in_fp32_modules) + if self.quantization_config.disable_conv_quantization: + modules_to_not_convert.extend(self.get_conv_param_names(model)) + + for module in modules_to_not_convert: + self.quantization_config.modelopt_config["quant_cfg"]["*" + module + "*"] = {"enable": False} + self.quantization_config.modules_to_not_convert = modules_to_not_convert + mto.apply_mode(model, mode=[("quantize", self.quantization_config.modelopt_config)]) + model.config.quantization_config = self.quantization_config + + def _process_model_after_weight_loading(self, model, **kwargs): + # ModelOpt imports diffusers internally. This is here to prevent circular imports + from modelopt.torch.opt import ModeloptStateManager + + if self.pre_quantized: + return model + + for _, m in model.named_modules(): + if hasattr(m, ModeloptStateManager._state_key) and m is not model: + ModeloptStateManager.remove_state(m) + + return model + + @property + def is_trainable(self): + return True + + @property + def is_serializable(self): + self.quantization_config.check_model_patching(operation="saving") + return True diff --git a/diffusers/quantizers/quanto/__init__.py b/diffusers/quantizers/quanto/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a4e8a1f41a1e4bf3bbb53c5684398277eb671a83 --- /dev/null +++ b/diffusers/quantizers/quanto/__init__.py @@ -0,0 +1 @@ +from .quanto_quantizer import QuantoQuantizer diff --git a/diffusers/quantizers/quanto/quanto_quantizer.py b/diffusers/quantizers/quanto/quanto_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..a036dabfe6f4bd971860a7be1c067d7e1066e09e --- /dev/null +++ b/diffusers/quantizers/quanto/quanto_quantizer.py @@ -0,0 +1,181 @@ +from typing import TYPE_CHECKING, Any + +from diffusers.utils.import_utils import is_optimum_quanto_version + +from ...utils import ( + get_module_from_name, + is_accelerate_available, + is_accelerate_version, + is_optimum_quanto_available, + is_torch_available, + logging, +) +from ..base import DiffusersQuantizer + + +if TYPE_CHECKING: + from ...models.modeling_utils import ModelMixin + + +if is_torch_available(): + import torch + +if is_accelerate_available(): + from accelerate.utils import CustomDtype, set_module_tensor_to_device + +if is_optimum_quanto_available(): + from .utils import _replace_with_quanto_layers + +logger = logging.get_logger(__name__) + + +class QuantoQuantizer(DiffusersQuantizer): + r""" + Diffusers Quantizer for Optimum Quanto + """ + + use_keep_in_fp32_modules = True + requires_calibration = False + required_packages = ["quanto", "accelerate"] + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + + def validate_environment(self, *args, **kwargs): + if not is_optimum_quanto_available(): + raise ImportError( + "Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)" + ) + if not is_optimum_quanto_version(">=", "0.2.6"): + raise ImportError( + "Loading an optimum-quanto quantized model requires `optimum-quanto>=0.2.6`. " + "Please upgrade your installation with `pip install --upgrade optimum-quanto" + ) + + if not is_accelerate_available(): + raise ImportError( + "Loading an optimum-quanto quantized model requires accelerate library (`pip install accelerate`)" + ) + + device_map = kwargs.get("device_map", None) + if isinstance(device_map, dict) and len(device_map.keys()) > 1: + raise ValueError( + "`device_map` for multi-GPU inference or CPU/disk offload is currently not supported with Diffusers and the Quanto backend" + ) + + def check_if_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + state_dict: dict[str, Any], + **kwargs, + ): + # Quanto imports diffusers internally. This is here to prevent circular imports + from optimum.quanto import QModuleMixin, QTensor + from optimum.quanto.tensor.packed import PackedTensor + + module, tensor_name = get_module_from_name(model, param_name) + if self.pre_quantized and any(isinstance(module, t) for t in [QTensor, PackedTensor]): + return True + elif isinstance(module, QModuleMixin) and "weight" in tensor_name: + return not module.frozen + + return False + + def create_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + *args, + **kwargs, + ): + """ + Create the quantized parameter by calling .freeze() after setting it to the module. + """ + + dtype = kwargs.get("dtype", torch.float32) + module, tensor_name = get_module_from_name(model, param_name) + if self.pre_quantized: + setattr(module, tensor_name, param_value) + else: + set_module_tensor_to_device(model, param_name, target_device, param_value, dtype) + module.freeze() + module.weight.requires_grad = False + + def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]: + max_memory = {key: val * 0.90 for key, val in max_memory.items()} + return max_memory + + def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": + if is_accelerate_version(">=", "0.27.0"): + mapping = { + "int8": torch.int8, + "float8": CustomDtype.FP8, + "int4": CustomDtype.INT4, + "int2": CustomDtype.INT2, + } + target_dtype = mapping[self.quantization_config.weights_dtype] + + return target_dtype + + def update_torch_dtype(self, torch_dtype: "torch.dtype" = None) -> "torch.dtype": + if torch_dtype is None: + logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.") + torch_dtype = torch.float32 + return torch_dtype + + def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]: + # Quanto imports diffusers internally. This is here to prevent circular imports + from optimum.quanto import QModuleMixin + + not_missing_keys = [] + for name, module in model.named_modules(): + if isinstance(module, QModuleMixin): + for missing in missing_keys: + if ( + (name in missing or name in f"{prefix}.{missing}") + and not missing.endswith(".weight") + and not missing.endswith(".bias") + ): + not_missing_keys.append(missing) + return [k for k in missing_keys if k not in not_missing_keys] + + def _process_model_before_weight_loading( + self, + model: "ModelMixin", + device_map, + keep_in_fp32_modules: list[str] = [], + **kwargs, + ): + self.modules_to_not_convert = self.quantization_config.modules_to_not_convert + + if not isinstance(self.modules_to_not_convert, list): + self.modules_to_not_convert = [self.modules_to_not_convert] + + self.modules_to_not_convert.extend(keep_in_fp32_modules) + + model = _replace_with_quanto_layers( + model, + modules_to_not_convert=self.modules_to_not_convert, + quantization_config=self.quantization_config, + pre_quantized=self.pre_quantized, + ) + model.config.quantization_config = self.quantization_config + + def _process_model_after_weight_loading(self, model, **kwargs): + return model + + @property + def is_trainable(self): + return True + + @property + def is_serializable(self): + return True + + @property + def is_compileable(self) -> bool: + return True diff --git a/diffusers/quantizers/quanto/utils.py b/diffusers/quantizers/quanto/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6f41fd36b43a7c222de17a297523939d27f34537 --- /dev/null +++ b/diffusers/quantizers/quanto/utils.py @@ -0,0 +1,60 @@ +import torch.nn as nn + +from ...utils import is_accelerate_available, logging + + +logger = logging.get_logger(__name__) + +if is_accelerate_available(): + from accelerate import init_empty_weights + + +def _replace_with_quanto_layers(model, quantization_config, modules_to_not_convert: list, pre_quantized=False): + # Quanto imports diffusers internally. These are placed here to avoid circular imports + from optimum.quanto import QLinear, freeze, qfloat8, qint2, qint4, qint8 + + def _get_weight_type(dtype: str): + return {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}[dtype] + + def _replace_layers(model, quantization_config, modules_to_not_convert): + has_children = list(model.children()) + if not has_children: + return model + + for name, module in model.named_children(): + _replace_layers(module, quantization_config, modules_to_not_convert) + + if name in modules_to_not_convert: + continue + + if isinstance(module, nn.Linear): + with init_empty_weights(): + qlinear = QLinear( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + dtype=module.weight.dtype, + weights=_get_weight_type(quantization_config.weights_dtype), + ) + model._modules[name] = qlinear + model._modules[name].source_cls = type(module) + model._modules[name].requires_grad_(False) + + return model + + model = _replace_layers(model, quantization_config, modules_to_not_convert) + has_been_replaced = any(isinstance(replaced_module, QLinear) for _, replaced_module in model.named_modules()) + + if not has_been_replaced: + logger.warning( + f"{model.__class__.__name__} does not appear to have any `nn.Linear` modules. Quantization will not be applied." + " Please check your model architecture, or submit an issue on Github if you think this is a bug." + " https://github.com/huggingface/diffusers/issues/new" + ) + + # We need to freeze the pre_quantized model in order for the loaded state_dict and model state dict + # to match when trying to load weights with load_model_dict_into_meta + if pre_quantized: + freeze(model) + + return model diff --git a/diffusers/quantizers/torchao/__init__.py b/diffusers/quantizers/torchao/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c56bf54c25158fb9e7ef9ff3bc1eb651743764fb --- /dev/null +++ b/diffusers/quantizers/torchao/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .torchao_quantizer import TorchAoHfQuantizer diff --git a/diffusers/quantizers/torchao/torchao_quantizer.py b/diffusers/quantizers/torchao/torchao_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..1679ed26a1046d2d8084bc175f30207f22fb37c7 --- /dev/null +++ b/diffusers/quantizers/torchao/torchao_quantizer.py @@ -0,0 +1,424 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Adapted from +https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac174797063e0/src/transformers/quantizers/quantizer_torchao.py +""" + +import importlib +import re +import types +from fnmatch import fnmatch +from typing import TYPE_CHECKING, Any + +from packaging import version + +from ...utils import ( + get_module_from_name, + is_torch_available, + is_torch_version, + is_torchao_available, + is_torchao_version, + logging, +) +from ..base import DiffusersQuantizer + + +logger = logging.get_logger(__name__) + + +if TYPE_CHECKING: + from ...models.modeling_utils import ModelMixin + + +if is_torch_available(): + import torch + import torch.nn as nn + + if is_torch_version(">=", "2.5"): + SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = ( + # At the moment, only int8 is supported for integer quantization dtypes. + # In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future + # to support more quantization methods, such as intx_weight_only. + torch.int8, + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.uint1, + torch.uint2, + torch.uint3, + torch.uint4, + torch.uint5, + torch.uint6, + torch.uint7, + ) + else: + SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = ( + torch.int8, + torch.float8_e4m3fn, + torch.float8_e5m2, + ) + +if is_torchao_available(): + from torchao.quantization import quantize_ + + +def _update_torch_safe_globals(): + safe_globals = [ + (torch.uint1, "torch.uint1"), + (torch.uint2, "torch.uint2"), + (torch.uint3, "torch.uint3"), + (torch.uint4, "torch.uint4"), + (torch.uint5, "torch.uint5"), + (torch.uint6, "torch.uint6"), + (torch.uint7, "torch.uint7"), + ] + try: + from torchao.dtypes import NF4Tensor + from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor + + safe_globals.extend([UintxTensor, UintxAQTTensorImpl, NF4Tensor]) + + # note: is_torchao_version(">=", "0.16.0") does not work correctly + # with torchao nightly, so using a ">" check which does work correctly + if is_torchao_version(">", "0.15.0"): + pass + else: + from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl + from torchao.dtypes.uintx.uint4_layout import UInt4Tensor + + safe_globals.extend([UInt4Tensor, Float8AQTTensorImpl]) + + except (ImportError, ModuleNotFoundError) as e: + logger.warning( + "Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`" + ) + logger.debug(e) + + finally: + torch.serialization.add_safe_globals(safe_globals=safe_globals) + + +if ( + is_torch_available() + and is_torch_version(">=", "2.6.0") + and is_torchao_available() + and is_torchao_version(">=", "0.7.0") +): + _update_torch_safe_globals() + + +def fuzzy_match_size(config_name: str) -> str | None: + """ + Extract the size digit from strings like "4weight", "8weight". Returns the digit as an integer if found, otherwise + None. + """ + config_name = config_name.lower() + + str_match = re.search(r"(\d)weight", config_name) + + if str_match: + return str_match.group(1) + + return None + + +def _quantization_type(weight): + from torchao.dtypes import AffineQuantizedTensor + from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor + + if isinstance(weight, AffineQuantizedTensor): + return f"{weight.__class__.__name__}({weight._quantization_type()})" + + if isinstance(weight, LinearActivationQuantizedTensor): + return f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})" + + +def _linear_extra_repr(self): + weight = _quantization_type(self.weight) + if weight is None: + return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight=None" + else: + return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={weight}" + + +class TorchAoHfQuantizer(DiffusersQuantizer): + r""" + Diffusers Quantizer for TorchAO: https://github.com/pytorch/ao/. + """ + + requires_calibration = False + required_packages = ["torchao"] + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + + def validate_environment(self, *args, **kwargs): + if not is_torchao_available(): + raise ImportError( + "Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`" + ) + torchao_version = version.parse(importlib.metadata.version("torch")) + if torchao_version < version.parse("0.7.0"): + raise RuntimeError( + f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`." + ) + + self.offload = False + + device_map = kwargs.get("device_map", None) + if isinstance(device_map, dict): + if "cpu" in device_map.values() or "disk" in device_map.values(): + if self.pre_quantized: + raise ValueError( + "You are attempting to perform cpu/disk offload with a pre-quantized torchao model " + "This is not supported yet. Please remove the CPU or disk device from the `device_map` argument." + ) + else: + self.offload = True + + if self.pre_quantized: + weights_only = kwargs.get("weights_only", None) + if weights_only: + torch_version = version.parse(importlib.metadata.version("torch")) + if torch_version < version.parse("2.5.0"): + # TODO(aryan): TorchAO is compatible with Pytorch >= 2.2 for certain quantization types. Try to see if we can support it in future + raise RuntimeError( + f"In order to use TorchAO pre-quantized model, you need to have torch>=2.5.0. However, the current version is {torch_version}." + ) + + def update_torch_dtype(self, torch_dtype): + quant_type = self.quantization_config.quant_type + if isinstance(quant_type, str) and (quant_type.startswith("int") or quant_type.startswith("uint")): + if torch_dtype is not None and torch_dtype != torch.bfloat16: + logger.warning( + f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but " + f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`." + ) + + if torch_dtype is None: + # We need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op + logger.warning( + "Overriding `torch_dtype` with `torch_dtype=torch.bfloat16` due to requirements of `torchao` " + "to enable model loading in different precisions. Pass your own `torch_dtype` to specify the " + "dtype of the remaining non-linear layers, or pass torch_dtype=torch.bfloat16, to remove this warning." + ) + torch_dtype = torch.bfloat16 + + return torch_dtype + + def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": + quant_type = self.quantization_config.quant_type + from accelerate.utils import CustomDtype + + if isinstance(quant_type, str): + if quant_type.startswith("int8"): + # Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8 + return torch.int8 + elif quant_type.startswith("int4"): + return CustomDtype.INT4 + elif quant_type == "uintx_weight_only": + return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8) + elif quant_type.startswith("uint"): + return { + 1: torch.uint1, + 2: torch.uint2, + 3: torch.uint3, + 4: torch.uint4, + 5: torch.uint5, + 6: torch.uint6, + 7: torch.uint7, + }[int(quant_type[4])] + elif quant_type.startswith("float") or quant_type.startswith("fp"): + return torch.bfloat16 + + elif is_torchao_version(">", "0.9.0"): + from torchao.core.config import AOBaseConfig + + quant_type = self.quantization_config.quant_type + if isinstance(quant_type, AOBaseConfig): + # Extract size digit using fuzzy match on the class name + config_name = quant_type.__class__.__name__ + size_digit = fuzzy_match_size(config_name) + + # Map the extracted digit to appropriate dtype + if size_digit == "4": + return CustomDtype.INT4 + else: + # Default to int8 + return torch.int8 + + if isinstance(target_dtype, SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION): + return target_dtype + + # We need one of the supported dtypes to be selected in order for accelerate to determine + # the total size of modules/parameters for auto device placement. + possible_device_maps = ["auto", "balanced", "balanced_low_0", "sequential"] + raise ValueError( + f"You have set `device_map` as one of {possible_device_maps} on a TorchAO quantized model but a suitable target dtype " + f"could not be inferred. The supported target_dtypes are: {SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION}. If you think the " + f"dtype you are using should be supported, please open an issue at https://github.com/huggingface/diffusers/issues." + ) + + def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]: + max_memory = {key: val * 0.9 for key, val in max_memory.items()} + return max_memory + + def check_if_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + state_dict: dict[str, Any], + **kwargs, + ) -> bool: + param_device = kwargs.pop("param_device", None) + # Check if the param_name is not in self.modules_to_not_convert + if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert): + return False + elif param_device == "cpu" and self.offload: + # We don't quantize weights that we offload + return False + else: + # We only quantize the weight of nn.Linear + module, tensor_name = get_module_from_name(model, param_name) + return isinstance(module, torch.nn.Linear) and (tensor_name == "weight") + + def create_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + state_dict: dict[str, Any], + unexpected_keys: list[str], + **kwargs, + ): + r""" + Each nn.Linear layer that needs to be quantized is processed here. First, we set the value the weight tensor, + then we move it to the target device. Finally, we quantize the module. + """ + module, tensor_name = get_module_from_name(model, param_name) + + if self.pre_quantized: + # If we're loading pre-quantized weights, replace the repr of linear layers for pretty printing info + # about AffineQuantizedTensor + module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device)) + if isinstance(module, nn.Linear): + module.extra_repr = types.MethodType(_linear_extra_repr, module) + else: + # As we perform quantization here, the repr of linear layers is that of AQT, so we don't have to do it ourselves + module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device) + quantize_(module, self.quantization_config.get_apply_tensor_subclass()) + + def get_cuda_warm_up_factor(self): + """ + This factor is used in caching_allocator_warmup to determine how many bytes to pre-allocate for CUDA warmup. + - A factor of 2 means we pre-allocate the full memory footprint of the model. + - A factor of 4 means we pre-allocate half of that, and so on + + However, when using TorchAO, calculating memory usage with param.numel() * param.element_size() doesn't give + the correct size for quantized weights (like int4 or int8) That's because TorchAO internally represents + quantized tensors using subtensors and metadata, and the reported element_size() still corresponds to the + torch_dtype not the actual bit-width of the quantized data. + + To correct for this: + - Use a division factor of 8 for int4 weights + - Use a division factor of 4 for int8 weights + """ + # Original mapping for non-AOBaseConfig types + # For the uint types, this is a best guess. Once these types become more used + # we can look into their nuances. + if is_torchao_version(">", "0.9.0"): + from torchao.core.config import AOBaseConfig + + quant_type = self.quantization_config.quant_type + if isinstance(quant_type, AOBaseConfig): + # Extract size digit using fuzzy match on the class name + config_name = quant_type.__class__.__name__ + size_digit = fuzzy_match_size(config_name) + + if size_digit == "4": + return 8 + else: + return 4 + + map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4} + quant_type = self.quantization_config.quant_type + for pattern, target_dtype in map_to_target_dtype.items(): + if fnmatch(quant_type, pattern): + return target_dtype + raise ValueError(f"Unsupported quant_type: {quant_type!r}") + + def _process_model_before_weight_loading( + self, + model: "ModelMixin", + device_map, + keep_in_fp32_modules: list[str] = [], + **kwargs, + ): + self.modules_to_not_convert = self.quantization_config.modules_to_not_convert + + if not isinstance(self.modules_to_not_convert, list): + self.modules_to_not_convert = [self.modules_to_not_convert] + + self.modules_to_not_convert.extend(keep_in_fp32_modules) + + # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk` + if isinstance(device_map, dict) and len(device_map.keys()) > 1: + keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]] + self.modules_to_not_convert.extend(keys_on_cpu) + + # Purge `None`. + # Unlike `transformers`, we don't know if we should always keep certain modules in FP32 + # in case of diffusion transformer models. For language models and others alike, `lm_head` + # and tied modules are usually kept in FP32. + self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None] + + model.config.quantization_config = self.quantization_config + + def _process_model_after_weight_loading(self, model: "ModelMixin"): + return model + + def is_serializable(self, safe_serialization=None): + # TODO(aryan): needs to be tested + if safe_serialization: + logger.warning( + "torchao quantized model does not support safe serialization, please set `safe_serialization` to False." + ) + return False + + _is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse( + "0.25.0" + ) + + if not _is_torchao_serializable: + logger.warning("torchao quantized model is only serializable after huggingface_hub >= 0.25.0 ") + + if self.offload and self.quantization_config.modules_to_not_convert is None: + logger.warning( + "The model contains offloaded modules and these modules are not quantized. We don't recommend saving the model as we won't be able to reload them." + "If you want to specify modules to not quantize, please specify modules_to_not_convert in the quantization_config." + ) + return False + + return _is_torchao_serializable + + @property + def is_trainable(self): + return self.quantization_config.quant_type.startswith("int8") + + @property + def is_compileable(self) -> bool: + return True diff --git a/einops-0.8.2.dist-info/licenses/LICENSE b/einops-0.8.2.dist-info/licenses/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..3a654e906619009358eb2cfe80609bd12b43fa7f --- /dev/null +++ b/einops-0.8.2.dist-info/licenses/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2018 Alex Rogozhnikov + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/fsspec/implementations/__init__.py b/fsspec/implementations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fsspec/implementations/arrow.py b/fsspec/implementations/arrow.py new file mode 100644 index 0000000000000000000000000000000000000000..530df901a7225bab4afb9f08d06540bc18e91aef --- /dev/null +++ b/fsspec/implementations/arrow.py @@ -0,0 +1,304 @@ +import errno +import io +import os +import secrets +import shutil +from contextlib import suppress +from functools import cached_property, wraps +from urllib.parse import parse_qs + +from fsspec.spec import AbstractFileSystem +from fsspec.utils import ( + get_package_version_without_import, + infer_storage_options, + mirror_from, + tokenize, +) + + +def wrap_exceptions(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except OSError as exception: + if not exception.args: + raise + + message, *args = exception.args + if isinstance(message, str) and "does not exist" in message: + raise FileNotFoundError(errno.ENOENT, message) from exception + else: + raise + + return wrapper + + +PYARROW_VERSION = None + + +class ArrowFSWrapper(AbstractFileSystem): + """FSSpec-compatible wrapper of pyarrow.fs.FileSystem. + + Parameters + ---------- + fs : pyarrow.fs.FileSystem + + """ + + root_marker = "/" + + def __init__(self, fs, **kwargs): + global PYARROW_VERSION + PYARROW_VERSION = get_package_version_without_import("pyarrow") + self.fs = fs + super().__init__(**kwargs) + + @property + def protocol(self): + return self.fs.type_name + + @cached_property + def fsid(self): + return "hdfs_" + tokenize(self.fs.host, self.fs.port) + + @classmethod + def _strip_protocol(cls, path): + ops = infer_storage_options(path) + path = ops["path"] + if path.startswith("//"): + # special case for "hdfs://path" (without the triple slash) + path = path[1:] + return path + + def ls(self, path, detail=False, **kwargs): + path = self._strip_protocol(path) + from pyarrow.fs import FileSelector + + entries = [ + self._make_entry(entry) + for entry in self.fs.get_file_info(FileSelector(path)) + ] + if detail: + return entries + else: + return [entry["name"] for entry in entries] + + def info(self, path, **kwargs): + path = self._strip_protocol(path) + [info] = self.fs.get_file_info([path]) + return self._make_entry(info) + + def exists(self, path): + path = self._strip_protocol(path) + try: + self.info(path) + except FileNotFoundError: + return False + else: + return True + + def _make_entry(self, info): + from pyarrow.fs import FileType + + if info.type is FileType.Directory: + kind = "directory" + elif info.type is FileType.File: + kind = "file" + elif info.type is FileType.NotFound: + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), info.path) + else: + kind = "other" + + return { + "name": info.path, + "size": info.size, + "type": kind, + "mtime": info.mtime, + } + + @wrap_exceptions + def cp_file(self, path1, path2, **kwargs): + path1 = self._strip_protocol(path1).rstrip("/") + path2 = self._strip_protocol(path2).rstrip("/") + + with self._open(path1, "rb") as lstream: + tmp_fname = f"{path2}.tmp.{secrets.token_hex(6)}" + try: + with self.open(tmp_fname, "wb") as rstream: + shutil.copyfileobj(lstream, rstream) + self.fs.move(tmp_fname, path2) + except BaseException: + with suppress(FileNotFoundError): + self.fs.delete_file(tmp_fname) + raise + + @wrap_exceptions + def mv(self, path1, path2, **kwargs): + path1 = self._strip_protocol(path1).rstrip("/") + path2 = self._strip_protocol(path2).rstrip("/") + self.fs.move(path1, path2) + + @wrap_exceptions + def rm_file(self, path): + path = self._strip_protocol(path) + self.fs.delete_file(path) + + @wrap_exceptions + def rm(self, path, recursive=False, maxdepth=None): + path = self._strip_protocol(path).rstrip("/") + if self.isdir(path): + if recursive: + self.fs.delete_dir(path) + else: + raise ValueError("Can't delete directories without recursive=False") + else: + self.fs.delete_file(path) + + @wrap_exceptions + def _open(self, path, mode="rb", block_size=None, seekable=True, **kwargs): + if mode == "rb": + if seekable: + method = self.fs.open_input_file + else: + method = self.fs.open_input_stream + elif mode == "wb": + method = self.fs.open_output_stream + elif mode == "ab": + method = self.fs.open_append_stream + else: + raise ValueError(f"unsupported mode for Arrow filesystem: {mode!r}") + + _kwargs = {} + if mode != "rb" or not seekable: + if int(PYARROW_VERSION.split(".")[0]) >= 4: + # disable compression auto-detection + _kwargs["compression"] = None + stream = method(path, **_kwargs) + + return ArrowFile(self, stream, path, mode, block_size, **kwargs) + + @wrap_exceptions + def mkdir(self, path, create_parents=True, **kwargs): + path = self._strip_protocol(path) + if create_parents: + self.makedirs(path, exist_ok=True) + else: + self.fs.create_dir(path, recursive=False) + + @wrap_exceptions + def makedirs(self, path, exist_ok=False): + path = self._strip_protocol(path) + self.fs.create_dir(path, recursive=True) + + @wrap_exceptions + def rmdir(self, path): + path = self._strip_protocol(path) + self.fs.delete_dir(path) + + @wrap_exceptions + def modified(self, path): + path = self._strip_protocol(path) + return self.fs.get_file_info(path).mtime + + def cat_file(self, path, start=None, end=None, **kwargs): + kwargs["seekable"] = start not in [None, 0] + return super().cat_file(path, start=None, end=None, **kwargs) + + def get_file(self, rpath, lpath, **kwargs): + kwargs["seekable"] = False + super().get_file(rpath, lpath, **kwargs) + + +@mirror_from( + "stream", + [ + "read", + "seek", + "tell", + "write", + "readable", + "writable", + "close", + "size", + "seekable", + ], +) +class ArrowFile(io.IOBase): + def __init__(self, fs, stream, path, mode, block_size=None, **kwargs): + self.path = path + self.mode = mode + + self.fs = fs + self.stream = stream + + self.blocksize = self.block_size = block_size + self.kwargs = kwargs + + def __enter__(self): + return self + + def __exit__(self, *args): + return self.close() + + +class HadoopFileSystem(ArrowFSWrapper): + """A wrapper on top of the pyarrow.fs.HadoopFileSystem + to connect it's interface with fsspec""" + + protocol = "hdfs" + + def __init__( + self, + host="default", + port=0, + user=None, + kerb_ticket=None, + replication=3, + extra_conf=None, + **kwargs, + ): + """ + + Parameters + ---------- + host: str + Hostname, IP or "default" to try to read from Hadoop config + port: int + Port to connect on, or default from Hadoop config if 0 + user: str or None + If given, connect as this username + kerb_ticket: str or None + If given, use this ticket for authentication + replication: int + set replication factor of file for write operations. default value is 3. + extra_conf: None or dict + Passed on to HadoopFileSystem + """ + from pyarrow.fs import HadoopFileSystem + + fs = HadoopFileSystem( + host=host, + port=port, + user=user, + kerb_ticket=kerb_ticket, + replication=replication, + extra_conf=extra_conf, + ) + super().__init__(fs=fs, **kwargs) + + @staticmethod + def _get_kwargs_from_urls(path): + ops = infer_storage_options(path) + out = {} + if ops.get("host", None): + out["host"] = ops["host"] + if ops.get("username", None): + out["user"] = ops["username"] + if ops.get("port", None): + out["port"] = ops["port"] + if ops.get("url_query", None): + queries = parse_qs(ops["url_query"]) + if queries.get("replication", None): + out["replication"] = int(queries["replication"][0]) + return out diff --git a/fsspec/implementations/asyn_wrapper.py b/fsspec/implementations/asyn_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..36db9c1b42e1a057c5346f0bedac3f9ff015c401 --- /dev/null +++ b/fsspec/implementations/asyn_wrapper.py @@ -0,0 +1,122 @@ +import asyncio +import functools +import inspect + +import fsspec +from fsspec.asyn import AsyncFileSystem, running_async + + +def async_wrapper(func, obj=None, semaphore=None): + """ + Wraps a synchronous function to make it awaitable. + + Parameters + ---------- + func : callable + The synchronous function to wrap. + obj : object, optional + The instance to bind the function to, if applicable. + semaphore : asyncio.Semaphore, optional + A semaphore to limit concurrent calls. + + Returns + ------- + coroutine + An awaitable version of the function. + """ + + @functools.wraps(func) + async def wrapper(*args, **kwargs): + if semaphore: + async with semaphore: + return await asyncio.to_thread(func, *args, **kwargs) + return await asyncio.to_thread(func, *args, **kwargs) + + return wrapper + + +class AsyncFileSystemWrapper(AsyncFileSystem): + """ + A wrapper class to convert a synchronous filesystem into an asynchronous one. + + This class takes an existing synchronous filesystem implementation and wraps all + its methods to provide an asynchronous interface. + + Parameters + ---------- + sync_fs : AbstractFileSystem + The synchronous filesystem instance to wrap. + """ + + protocol = "asyncwrapper", "async_wrapper" + cachable = False + + def __init__( + self, + fs=None, + asynchronous=None, + target_protocol=None, + target_options=None, + semaphore=None, + max_concurrent_tasks=None, + **kwargs, + ): + if asynchronous is None: + asynchronous = running_async() + super().__init__(asynchronous=asynchronous, **kwargs) + if fs is not None: + self.sync_fs = fs + else: + self.sync_fs = fsspec.filesystem(target_protocol, **target_options) + self.protocol = self.sync_fs.protocol + self.semaphore = semaphore + self._wrap_all_sync_methods() + + @property + def fsid(self): + return f"async_{self.sync_fs.fsid}" + + def _wrap_all_sync_methods(self): + """ + Wrap all synchronous methods of the underlying filesystem with asynchronous versions. + """ + excluded_methods = {"open"} + for method_name in dir(self.sync_fs): + if method_name.startswith("_") or method_name in excluded_methods: + continue + + attr = inspect.getattr_static(self.sync_fs, method_name) + if isinstance(attr, property): + continue + + method = getattr(self.sync_fs, method_name) + if callable(method) and not inspect.iscoroutinefunction(method): + async_method = async_wrapper(method, obj=self, semaphore=self.semaphore) + setattr(self, f"_{method_name}", async_method) + + @classmethod + def wrap_class(cls, sync_fs_class): + """ + Create a new class that can be used to instantiate an AsyncFileSystemWrapper + with lazy instantiation of the underlying synchronous filesystem. + + Parameters + ---------- + sync_fs_class : type + The class of the synchronous filesystem to wrap. + + Returns + ------- + type + A new class that wraps the provided synchronous filesystem class. + """ + + class GeneratedAsyncFileSystemWrapper(cls): + def __init__(self, *args, **kwargs): + sync_fs = sync_fs_class(*args, **kwargs) + super().__init__(sync_fs) + + GeneratedAsyncFileSystemWrapper.__name__ = ( + f"Async{sync_fs_class.__name__}Wrapper" + ) + return GeneratedAsyncFileSystemWrapper diff --git a/fsspec/implementations/cache_mapper.py b/fsspec/implementations/cache_mapper.py new file mode 100644 index 0000000000000000000000000000000000000000..6e7c7d88afdddf12f77b26bb635bd8bf1e2bd7f1 --- /dev/null +++ b/fsspec/implementations/cache_mapper.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import abc +import hashlib + +from fsspec.implementations.local import make_path_posix + + +class AbstractCacheMapper(abc.ABC): + """Abstract super-class for mappers from remote URLs to local cached + basenames. + """ + + @abc.abstractmethod + def __call__(self, path: str) -> str: ... + + def __eq__(self, other: object) -> bool: + # Identity only depends on class. When derived classes have attributes + # they will need to be included. + return isinstance(other, type(self)) + + def __hash__(self) -> int: + # Identity only depends on class. When derived classes have attributes + # they will need to be included. + return hash(type(self)) + + +class BasenameCacheMapper(AbstractCacheMapper): + """Cache mapper that uses the basename of the remote URL and a fixed number + of directory levels above this. + + The default is zero directory levels, meaning different paths with the same + basename will have the same cached basename. + """ + + def __init__(self, directory_levels: int = 0): + if directory_levels < 0: + raise ValueError( + "BasenameCacheMapper requires zero or positive directory_levels" + ) + self.directory_levels = directory_levels + + # Separator for directories when encoded as strings. + self._separator = "_@_" + + def __call__(self, path: str) -> str: + path = make_path_posix(path) + prefix, *bits = path.rsplit("/", self.directory_levels + 1) + if bits: + return self._separator.join(bits) + else: + return prefix # No separator found, simple filename + + def __eq__(self, other: object) -> bool: + return super().__eq__(other) and self.directory_levels == other.directory_levels + + def __hash__(self) -> int: + return super().__hash__() ^ hash(self.directory_levels) + + +class HashCacheMapper(AbstractCacheMapper): + """Cache mapper that uses a hash of the remote URL.""" + + def __call__(self, path: str) -> str: + return hashlib.sha256(path.encode()).hexdigest() + + +def create_cache_mapper(same_names: bool) -> AbstractCacheMapper: + """Factory method to create cache mapper for backward compatibility with + ``CachingFileSystem`` constructor using ``same_names`` kwarg. + """ + if same_names: + return BasenameCacheMapper() + else: + return HashCacheMapper() diff --git a/fsspec/implementations/cache_metadata.py b/fsspec/implementations/cache_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..4a519158d4aca975dcbcc5978aeb57dab0e53cb0 --- /dev/null +++ b/fsspec/implementations/cache_metadata.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +import os +import pickle +import time +from typing import TYPE_CHECKING + +from fsspec.utils import atomic_write + +try: + import ujson as json +except ImportError: + if not TYPE_CHECKING: + import json + +if TYPE_CHECKING: + from collections.abc import Iterator + from typing import Any, Literal + + from typing_extensions import TypeAlias + + from .cached import CachingFileSystem + + Detail: TypeAlias = dict[str, Any] + + +class CacheMetadata: + """Cache metadata. + + All reading and writing of cache metadata is performed by this class, + accessing the cached files and blocks is not. + + Metadata is stored in a single file per storage directory in JSON format. + For backward compatibility, also reads metadata stored in pickle format + which is converted to JSON when next saved. + """ + + def __init__(self, storage: list[str]): + """ + + Parameters + ---------- + storage: list[str] + Directories containing cached files, must be at least one. Metadata + is stored in the last of these directories by convention. + """ + if not storage: + raise ValueError("CacheMetadata expects at least one storage location") + + self._storage = storage + self.cached_files: list[Detail] = [{}] + + # Private attribute to force saving of metadata in pickle format rather than + # JSON for use in tests to confirm can read both pickle and JSON formats. + self._force_save_pickle = False + + def _load(self, fn: str) -> Detail: + """Low-level function to load metadata from specific file""" + try: + with open(fn, "r") as f: + loaded = json.load(f) + except ValueError: + with open(fn, "rb") as f: + loaded = pickle.load(f) + for c in loaded.values(): + if isinstance(c.get("blocks"), list): + c["blocks"] = set(c["blocks"]) + return loaded + + def _save(self, metadata_to_save: Detail, fn: str) -> None: + """Low-level function to save metadata to specific file""" + if self._force_save_pickle: + with atomic_write(fn) as f: + pickle.dump(metadata_to_save, f) + else: + with atomic_write(fn, mode="w") as f: + json.dump(metadata_to_save, f) + + def _scan_locations( + self, writable_only: bool = False + ) -> Iterator[tuple[str, str, bool]]: + """Yield locations (filenames) where metadata is stored, and whether + writable or not. + + Parameters + ---------- + writable: bool + Set to True to only yield writable locations. + + Returns + ------- + Yields (str, str, bool) + """ + n = len(self._storage) + for i, storage in enumerate(self._storage): + writable = i == n - 1 + if writable_only and not writable: + continue + yield os.path.join(storage, "cache"), storage, writable + + def check_file( + self, path: str, cfs: CachingFileSystem | None + ) -> Literal[False] | tuple[Detail, str]: + """If path is in cache return its details, otherwise return ``False``. + + If the optional CachingFileSystem is specified then it is used to + perform extra checks to reject possible matches, such as if they are + too old. + """ + for (fn, base, _), cache in zip(self._scan_locations(), self.cached_files): + if path not in cache: + continue + detail = cache[path].copy() + + if cfs is not None: + if cfs.check_files and detail["uid"] != cfs.fs.ukey(path): + # Wrong file as determined by hash of file properties + continue + if cfs.expiry and time.time() - detail["time"] > cfs.expiry: + # Cached file has expired + continue + + fn = os.path.join(base, detail["fn"]) + if os.path.exists(fn): + return detail, fn + return False + + def clear_expired(self, expiry_time: int) -> tuple[list[str], bool]: + """Remove expired metadata from the cache. + + Returns names of files corresponding to expired metadata and a boolean + flag indicating whether the writable cache is empty. Caller is + responsible for deleting the expired files. + """ + expired_files = [] + for path, detail in self.cached_files[-1].copy().items(): + if time.time() - detail["time"] > expiry_time: + fn = detail.get("fn", "") + if not fn: + raise RuntimeError( + f"Cache metadata does not contain 'fn' for {path}" + ) + fn = os.path.join(self._storage[-1], fn) + expired_files.append(fn) + self.cached_files[-1].pop(path) + + if self.cached_files[-1]: + cache_path = os.path.join(self._storage[-1], "cache") + self._save(self.cached_files[-1], cache_path) + + writable_cache_empty = not self.cached_files[-1] + return expired_files, writable_cache_empty + + def load(self) -> None: + """Load all metadata from disk and store in ``self.cached_files``""" + cached_files = [] + for fn, _, _ in self._scan_locations(): + if os.path.exists(fn): + # TODO: consolidate blocks here + cached_files.append(self._load(fn)) + else: + cached_files.append({}) + self.cached_files = cached_files or [{}] + + def on_close_cached_file(self, f: Any, path: str) -> None: + """Perform side-effect actions on closing a cached file. + + The actual closing of the file is the responsibility of the caller. + """ + # File must be writeble, so in self.cached_files[-1] + c = self.cached_files[-1][path] + if c["blocks"] is not True and len(c["blocks"]) * f.blocksize >= f.size: + c["blocks"] = True + + def pop_file(self, path: str) -> str | None: + """Remove metadata of cached file. + + If path is in the cache, return the filename of the cached file, + otherwise return ``None``. Caller is responsible for deleting the + cached file. + """ + details = self.check_file(path, None) + if not details: + return None + _, fn = details + if fn.startswith(self._storage[-1]): + self.cached_files[-1].pop(path) + self.save() + else: + raise PermissionError( + "Can only delete cached file in last, writable cache location" + ) + return fn + + def save(self) -> None: + """Save metadata to disk""" + for (fn, _, writable), cache in zip(self._scan_locations(), self.cached_files): + if not writable: + continue + + if os.path.exists(fn): + cached_files = self._load(fn) + for k, c in cached_files.items(): + if k in cache: + if c["blocks"] is True or cache[k]["blocks"] is True: + c["blocks"] = True + else: + # self.cached_files[*][*]["blocks"] must continue to + # point to the same set object so that updates + # performed by MMapCache are propagated back to + # self.cached_files. + blocks = cache[k]["blocks"] + blocks.update(c["blocks"]) + c["blocks"] = blocks + c["time"] = max(c["time"], cache[k]["time"]) + c["uid"] = cache[k]["uid"] + + # Files can be added to cache after it was written once + for k, c in cache.items(): + if k not in cached_files: + cached_files[k] = c + else: + cached_files = cache + cache = {k: v.copy() for k, v in cached_files.items()} + for c in cache.values(): + if isinstance(c["blocks"], set): + c["blocks"] = list(c["blocks"]) + self._save(cache, fn) + self.cached_files[-1] = cached_files + + def update_file(self, path: str, detail: Detail) -> None: + """Update metadata for specific file in memory, do not save""" + self.cached_files[-1][path] = detail diff --git a/fsspec/implementations/cached.py b/fsspec/implementations/cached.py new file mode 100644 index 0000000000000000000000000000000000000000..4847c2ef262b688b25508b2861f394426ba2dd73 --- /dev/null +++ b/fsspec/implementations/cached.py @@ -0,0 +1,998 @@ +from __future__ import annotations + +import inspect +import logging +import os +import tempfile +import time +import weakref +from shutil import rmtree +from typing import TYPE_CHECKING, Any, Callable, ClassVar + +from fsspec import AbstractFileSystem, filesystem +from fsspec.callbacks import DEFAULT_CALLBACK +from fsspec.compression import compr +from fsspec.core import BaseCache, MMapCache +from fsspec.exceptions import BlocksizeMismatchError +from fsspec.implementations.cache_mapper import create_cache_mapper +from fsspec.implementations.cache_metadata import CacheMetadata +from fsspec.implementations.local import LocalFileSystem +from fsspec.spec import AbstractBufferedFile +from fsspec.transaction import Transaction +from fsspec.utils import infer_compression + +if TYPE_CHECKING: + from fsspec.implementations.cache_mapper import AbstractCacheMapper + +logger = logging.getLogger("fsspec.cached") + + +class WriteCachedTransaction(Transaction): + def complete(self, commit=True): + rpaths = [f.path for f in self.files] + lpaths = [f.fn for f in self.files] + if commit: + self.fs.put(lpaths, rpaths) + self.files.clear() + self.fs._intrans = False + self.fs._transaction = None + self.fs = None # break cycle + + +class CachingFileSystem(AbstractFileSystem): + """Locally caching filesystem, layer over any other FS + + This class implements chunk-wise local storage of remote files, for quick + access after the initial download. The files are stored in a given + directory with hashes of URLs for the filenames. If no directory is given, + a temporary one is used, which should be cleaned up by the OS after the + process ends. The files themselves are sparse (as implemented in + :class:`~fsspec.caching.MMapCache`), so only the data which is accessed + takes up space. + + Restrictions: + + - the block-size must be the same for each access of a given file, unless + all blocks of the file have already been read + - caching can only be applied to file-systems which produce files + derived from fsspec.spec.AbstractBufferedFile ; LocalFileSystem is also + allowed, for testing + """ + + protocol: ClassVar[str | tuple[str, ...]] = ("blockcache", "cached") + + def __init__( + self, + target_protocol=None, + cache_storage="TMP", + cache_check=10, + check_files=False, + expiry_time=604800, + target_options=None, + fs=None, + same_names: bool | None = None, + compression=None, + cache_mapper: AbstractCacheMapper | None = None, + **kwargs, + ): + """ + + Parameters + ---------- + target_protocol: str (optional) + Target filesystem protocol. Provide either this or ``fs``. + cache_storage: str or list(str) + Location to store files. If "TMP", this is a temporary directory, + and will be cleaned up by the OS when this process ends (or later). + If a list, each location will be tried in the order given, but + only the last will be considered writable. + cache_check: int + Number of seconds between reload of cache metadata + check_files: bool + Whether to explicitly see if the UID of the remote file matches + the stored one before using. Warning: some file systems such as + HTTP cannot reliably give a unique hash of the contents of some + path, so be sure to set this option to False. + expiry_time: int + The time in seconds after which a local copy is considered useless. + Set to falsy to prevent expiry. The default is equivalent to one + week. + target_options: dict or None + Passed to the instantiation of the FS, if fs is None. + fs: filesystem instance + The target filesystem to run against. Provide this or ``protocol``. + same_names: bool (optional) + By default, target URLs are hashed using a ``HashCacheMapper`` so + that files from different backends with the same basename do not + conflict. If this argument is ``true``, a ``BasenameCacheMapper`` + is used instead. Other cache mapper options are available by using + the ``cache_mapper`` keyword argument. Only one of this and + ``cache_mapper`` should be specified. + compression: str (optional) + To decompress on download. Can be 'infer' (guess from the URL name), + one of the entries in ``fsspec.compression.compr``, or None for no + decompression. + cache_mapper: AbstractCacheMapper (optional) + The object use to map from original filenames to cached filenames. + Only one of this and ``same_names`` should be specified. + """ + super().__init__(**kwargs) + if fs is None and target_protocol is None: + raise ValueError( + "Please provide filesystem instance(fs) or target_protocol" + ) + if not (fs is None) ^ (target_protocol is None): + raise ValueError( + "Both filesystems (fs) and target_protocol may not be both given." + ) + if cache_storage == "TMP": + tempdir = tempfile.mkdtemp() + storage = [tempdir] + weakref.finalize(self, self._remove_tempdir, tempdir) + else: + if isinstance(cache_storage, str): + storage = [cache_storage] + else: + storage = cache_storage + os.makedirs(storage[-1], exist_ok=True) + self.storage = storage + self.kwargs = target_options or {} + self.cache_check = cache_check + self.check_files = check_files + self.expiry = expiry_time + self.compression = compression + + # Size of cache in bytes. If None then the size is unknown and will be + # recalculated the next time cache_size() is called. On writes to the + # cache this is reset to None. + self._cache_size = None + + if same_names is not None and cache_mapper is not None: + raise ValueError( + "Cannot specify both same_names and cache_mapper in " + "CachingFileSystem.__init__" + ) + if cache_mapper is not None: + self._mapper = cache_mapper + else: + self._mapper = create_cache_mapper( + same_names if same_names is not None else False + ) + + self.target_protocol = ( + target_protocol + if isinstance(target_protocol, str) + else (fs.protocol if isinstance(fs.protocol, str) else fs.protocol[0]) + ) + self._metadata = CacheMetadata(self.storage) + self.load_cache() + self.fs = fs if fs is not None else filesystem(target_protocol, **self.kwargs) + + def _strip_protocol(path): + # acts as a method, since each instance has a difference target + return self.fs._strip_protocol(type(self)._strip_protocol(path)) + + self._strip_protocol: Callable = _strip_protocol + + @staticmethod + def _remove_tempdir(tempdir): + try: + rmtree(tempdir) + except Exception: + pass + + def _mkcache(self): + os.makedirs(self.storage[-1], exist_ok=True) + + def cache_size(self): + """Return size of cache in bytes. + + If more than one cache directory is in use, only the size of the last + one (the writable cache directory) is returned. + """ + if self._cache_size is None: + cache_dir = self.storage[-1] + self._cache_size = filesystem("file").du(cache_dir, withdirs=True) + return self._cache_size + + def load_cache(self): + """Read set of stored blocks from file""" + self._metadata.load() + self._mkcache() + self.last_cache = time.time() + + def save_cache(self): + """Save set of stored blocks from file""" + self._mkcache() + self._metadata.save() + self.last_cache = time.time() + self._cache_size = None + + def _check_cache(self): + """Reload caches if time elapsed or any disappeared""" + self._mkcache() + if not self.cache_check: + # explicitly told not to bother checking + return + timecond = time.time() - self.last_cache > self.cache_check + existcond = all(os.path.exists(storage) for storage in self.storage) + if timecond or not existcond: + self.load_cache() + + def _check_file(self, path): + """Is path in cache and still valid""" + path = self._strip_protocol(path) + self._check_cache() + return self._metadata.check_file(path, self) + + def clear_cache(self): + """Remove all files and metadata from the cache + + In the case of multiple cache locations, this clears only the last one, + which is assumed to be the read/write one. + """ + rmtree(self.storage[-1]) + self.load_cache() + self._cache_size = None + + def clear_expired_cache(self, expiry_time=None): + """Remove all expired files and metadata from the cache + + In the case of multiple cache locations, this clears only the last one, + which is assumed to be the read/write one. + + Parameters + ---------- + expiry_time: int + The time in seconds after which a local copy is considered useless. + If not defined the default is equivalent to the attribute from the + file caching instantiation. + """ + + if not expiry_time: + expiry_time = self.expiry + + self._check_cache() + + expired_files, writable_cache_empty = self._metadata.clear_expired(expiry_time) + for fn in expired_files: + if os.path.exists(fn): + os.remove(fn) + + if writable_cache_empty: + rmtree(self.storage[-1]) + self.load_cache() + + self._cache_size = None + + def pop_from_cache(self, path): + """Remove cached version of given file + + Deletes local copy of the given (remote) path. If it is found in a cache + location which is not the last, it is assumed to be read-only, and + raises PermissionError + """ + path = self._strip_protocol(path) + fn = self._metadata.pop_file(path) + if fn is not None: + os.remove(fn) + self._cache_size = None + + def _open( + self, + path, + mode="rb", + block_size=None, + autocommit=True, + cache_options=None, + **kwargs, + ): + """Wrap the target _open + + If the whole file exists in the cache, just open it locally and + return that. + + Otherwise, open the file on the target FS, and make it have a mmap + cache pointing to the location which we determine, in our cache. + The ``blocks`` instance is shared, so as the mmap cache instance + updates, so does the entry in our ``cached_files`` attribute. + We monkey-patch this file, so that when it closes, we call + ``close_and_update`` to save the state of the blocks. + """ + path = self._strip_protocol(path) + + path = self.fs._strip_protocol(path) + if "r" not in mode: + return self.fs._open( + path, + mode=mode, + block_size=block_size, + autocommit=autocommit, + cache_options=cache_options, + **kwargs, + ) + detail = self._check_file(path) + if detail: + # file is in cache + detail, fn = detail + hash, blocks = detail["fn"], detail["blocks"] + if blocks is True: + # stored file is complete + logger.debug("Opening local copy of %s", path) + return open(fn, mode) + # TODO: action where partial file exists in read-only cache + logger.debug("Opening partially cached copy of %s", path) + else: + hash = self._mapper(path) + fn = os.path.join(self.storage[-1], hash) + blocks = set() + detail = { + "original": path, + "fn": hash, + "blocks": blocks, + "time": time.time(), + "uid": self.fs.ukey(path), + } + self._metadata.update_file(path, detail) + logger.debug("Creating local sparse file for %s", path) + + # explicitly submitting the size to the open call will avoid extra + # operations when opening. This is particularly relevant + # for any file that is read over a network, e.g. S3. + size = detail.get("size") + + # call target filesystems open + self._mkcache() + f = self.fs._open( + path, + mode=mode, + block_size=block_size, + autocommit=autocommit, + cache_options=cache_options, + cache_type="none", + size=size, + **kwargs, + ) + + # set size if not already set + if size is None: + detail["size"] = f.size + self._metadata.update_file(path, detail) + + if self.compression: + comp = ( + infer_compression(path) + if self.compression == "infer" + else self.compression + ) + f = compr[comp](f, mode="rb") + if "blocksize" in detail: + if detail["blocksize"] != f.blocksize: + raise BlocksizeMismatchError( + f"Cached file must be reopened with same block" + f" size as original (old: {detail['blocksize']}," + f" new {f.blocksize})" + ) + else: + detail["blocksize"] = f.blocksize + + def _fetch_ranges(ranges): + return self.fs.cat_ranges( + [path] * len(ranges), + [r[0] for r in ranges], + [r[1] for r in ranges], + **kwargs, + ) + + multi_fetcher = None if self.compression else _fetch_ranges + f.cache = MMapCache( + f.blocksize, f._fetch_range, f.size, fn, blocks, multi_fetcher=multi_fetcher + ) + close = f.close + f.close = lambda: self.close_and_update(f, close) + self.save_cache() + return f + + def _parent(self, path): + return self.fs._parent(path) + + def hash_name(self, path: str, *args: Any) -> str: + # Kept for backward compatibility with downstream libraries. + # Ignores extra arguments, previously same_name boolean. + return self._mapper(path) + + def close_and_update(self, f, close): + """Called when a file is closing, so store the set of blocks""" + if f.closed: + return + path = self._strip_protocol(f.path) + self._metadata.on_close_cached_file(f, path) + try: + logger.debug("going to save") + self.save_cache() + logger.debug("saved") + except OSError: + logger.debug("Cache saving failed while closing file") + except NameError: + logger.debug("Cache save failed due to interpreter shutdown") + close() + f.closed = True + + def ls(self, path, detail=True): + return self.fs.ls(path, detail) + + def __getattribute__(self, item): + if item in { + "load_cache", + "_open", + "save_cache", + "close_and_update", + "__init__", + "__getattribute__", + "__reduce__", + "_make_local_details", + "open", + "cat", + "cat_file", + "_cat_file", + "cat_ranges", + "_cat_ranges", + "get", + "read_block", + "tail", + "head", + "info", + "ls", + "exists", + "isfile", + "isdir", + "_check_file", + "_check_cache", + "_mkcache", + "clear_cache", + "clear_expired_cache", + "pop_from_cache", + "local_file", + "_paths_from_path", + "get_mapper", + "open_many", + "commit_many", + "hash_name", + "__hash__", + "__eq__", + "to_json", + "to_dict", + "cache_size", + "pipe_file", + "pipe", + "start_transaction", + "end_transaction", + }: + # all the methods defined in this class. Note `open` here, since + # it calls `_open`, but is actually in superclass + return lambda *args, **kw: getattr(type(self), item).__get__(self)( + *args, **kw + ) + if item in ["__reduce_ex__"]: + raise AttributeError + if item in ["transaction"]: + # property + return type(self).transaction.__get__(self) + if item in {"_cache", "transaction_type", "protocol"}: + # class attributes + return getattr(type(self), item) + if item == "__class__": + return type(self) + d = object.__getattribute__(self, "__dict__") + fs = d.get("fs", None) # fs is not immediately defined + if item in d: + return d[item] + elif fs is not None: + if item in fs.__dict__: + # attribute of instance + return fs.__dict__[item] + # attributed belonging to the target filesystem + cls = type(fs) + m = getattr(cls, item) + if (inspect.isfunction(m) or inspect.isdatadescriptor(m)) and ( + not hasattr(m, "__self__") or m.__self__ is None + ): + # instance method + return m.__get__(fs, cls) + return m # class method or attribute + else: + # attributes of the superclass, while target is being set up + return super().__getattribute__(item) + + def __eq__(self, other): + """Test for equality.""" + if self is other: + return True + if not isinstance(other, type(self)): + return False + return ( + self.storage == other.storage + and self.kwargs == other.kwargs + and self.cache_check == other.cache_check + and self.check_files == other.check_files + and self.expiry == other.expiry + and self.compression == other.compression + and self._mapper == other._mapper + and self.target_protocol == other.target_protocol + ) + + def __hash__(self): + """Calculate hash.""" + return ( + hash(tuple(self.storage)) + ^ hash(str(self.kwargs)) + ^ hash(self.cache_check) + ^ hash(self.check_files) + ^ hash(self.expiry) + ^ hash(self.compression) + ^ hash(self._mapper) + ^ hash(self.target_protocol) + ) + + +class WholeFileCacheFileSystem(CachingFileSystem): + """Caches whole remote files on first access + + This class is intended as a layer over any other file system, and + will make a local copy of each file accessed, so that all subsequent + reads are local. This is similar to ``CachingFileSystem``, but without + the block-wise functionality and so can work even when sparse files + are not allowed. See its docstring for definition of the init + arguments. + + The class still needs access to the remote store for listing files, + and may refresh cached files. + """ + + protocol = "filecache" + local_file = True + + def open_many(self, open_files, **kwargs): + paths = [of.path for of in open_files] + if "r" in open_files.mode: + self._mkcache() + else: + return [ + LocalTempFile( + self.fs, + path, + mode=open_files.mode, + fn=os.path.join(self.storage[-1], self._mapper(path)), + **kwargs, + ) + for path in paths + ] + + if self.compression: + raise NotImplementedError + details = [self._check_file(sp) for sp in paths] + downpath = [p for p, d in zip(paths, details) if not d] + downfn0 = [ + os.path.join(self.storage[-1], self._mapper(p)) + for p, d in zip(paths, details) + ] # keep these path names for opening later + downfn = [fn for fn, d in zip(downfn0, details) if not d] + if downpath: + # skip if all files are already cached and up to date + self.fs.get(downpath, downfn) + + # update metadata - only happens when downloads are successful + newdetail = [ + { + "original": path, + "fn": self._mapper(path), + "blocks": True, + "time": time.time(), + "uid": self.fs.ukey(path), + } + for path in downpath + ] + for path, detail in zip(downpath, newdetail): + self._metadata.update_file(path, detail) + self.save_cache() + + def firstpart(fn): + # helper to adapt both whole-file and simple-cache + return fn[1] if isinstance(fn, tuple) else fn + + return [ + open(firstpart(fn0) if fn0 else fn1, mode=open_files.mode) + for fn0, fn1 in zip(details, downfn0) + ] + + def commit_many(self, open_files): + self.fs.put([f.fn for f in open_files], [f.path for f in open_files]) + [f.close() for f in open_files] + for f in open_files: + # in case autocommit is off, and so close did not already delete + try: + os.remove(f.name) + except FileNotFoundError: + pass + self._cache_size = None + + def _make_local_details(self, path): + hash = self._mapper(path) + fn = os.path.join(self.storage[-1], hash) + detail = { + "original": path, + "fn": hash, + "blocks": True, + "time": time.time(), + "uid": self.fs.ukey(path), + } + self._metadata.update_file(path, detail) + logger.debug("Copying %s to local cache", path) + return fn + + def cat( + self, + path, + recursive=False, + on_error="raise", + callback=DEFAULT_CALLBACK, + **kwargs, + ): + paths = self.expand_path( + path, recursive=recursive, maxdepth=kwargs.get("maxdepth") + ) + getpaths = [] + storepaths = [] + fns = [] + out = {} + for p in paths.copy(): + try: + detail = self._check_file(p) + if not detail: + fn = self._make_local_details(p) + getpaths.append(p) + storepaths.append(fn) + else: + detail, fn = detail if isinstance(detail, tuple) else (None, detail) + fns.append(fn) + except Exception as e: + if on_error == "raise": + raise + if on_error == "return": + out[p] = e + paths.remove(p) + + if getpaths: + self.fs.get(getpaths, storepaths) + self.save_cache() + + callback.set_size(len(paths)) + for p, fn in zip(paths, fns): + with open(fn, "rb") as f: + out[p] = f.read() + callback.relative_update(1) + if isinstance(path, str) and len(paths) == 1 and recursive is False: + out = out[paths[0]] + return out + + def _open(self, path, mode="rb", **kwargs): + path = self._strip_protocol(path) + if "r" not in mode: + hash = self._mapper(path) + fn = os.path.join(self.storage[-1], hash) + user_specified_kwargs = { + k: v + for k, v in kwargs.items() + # those kwargs were added by open(), we don't want them + if k not in ["autocommit", "block_size", "cache_options"] + } + return LocalTempFile(self, path, mode=mode, fn=fn, **user_specified_kwargs) + detail = self._check_file(path) + if detail: + detail, fn = detail + _, blocks = detail["fn"], detail["blocks"] + if blocks is True: + logger.debug("Opening local copy of %s", path) + + # In order to support downstream filesystems to be able to + # infer the compression from the original filename, like + # the `TarFileSystem`, let's extend the `io.BufferedReader` + # fileobject protocol by adding a dedicated attribute + # `original`. + f = open(fn, mode) + f.original = detail.get("original") + return f + else: + raise ValueError( + f"Attempt to open partially cached file {path}" + f" as a wholly cached file" + ) + else: + fn = self._make_local_details(path) + kwargs["mode"] = mode + + # call target filesystems open + self._mkcache() + if self.compression: + with self.fs._open(path, **kwargs) as f, open(fn, "wb") as f2: + if isinstance(f, AbstractBufferedFile): + # want no type of caching if just downloading whole thing + f.cache = BaseCache(0, f.cache.fetcher, f.size) + comp = ( + infer_compression(path) + if self.compression == "infer" + else self.compression + ) + f = compr[comp](f, mode="rb") + data = True + while data: + block = getattr(f, "blocksize", 5 * 2**20) + data = f.read(block) + f2.write(data) + else: + self.fs.get_file(path, fn) + self.save_cache() + return self._open(path, mode) + + +class SimpleCacheFileSystem(WholeFileCacheFileSystem): + """Caches whole remote files on first access + + This class is intended as a layer over any other file system, and + will make a local copy of each file accessed, so that all subsequent + reads are local. This implementation only copies whole files, and + does not keep any metadata about the download time or file details. + It is therefore safer to use in multi-threaded/concurrent situations. + + This is the only of the caching filesystems that supports write: you will + be given a real local open file, and upon close and commit, it will be + uploaded to the target filesystem; the writability or the target URL is + not checked until that time. + + """ + + protocol = "simplecache" + local_file = True + transaction_type = WriteCachedTransaction + + def __init__(self, **kwargs): + kw = kwargs.copy() + for key in ["cache_check", "expiry_time", "check_files"]: + kw[key] = False + super().__init__(**kw) + for storage in self.storage: + if not os.path.exists(storage): + os.makedirs(storage, exist_ok=True) + + def _check_file(self, path): + self._check_cache() + sha = self._mapper(path) + for storage in self.storage: + fn = os.path.join(storage, sha) + if os.path.exists(fn): + return fn + + def save_cache(self): + pass + + def load_cache(self): + pass + + def pipe_file(self, path, value=None, **kwargs): + if self._intrans: + with self.open(path, "wb") as f: + f.write(value) + else: + super().pipe_file(path, value) + + def ls(self, path, detail=True, **kwargs): + path = self._strip_protocol(path) + details = [] + try: + details = self.fs.ls( + path, detail=True, **kwargs + ).copy() # don't edit original! + except FileNotFoundError as e: + ex = e + else: + ex = None + if self._intrans: + path1 = path.rstrip("/") + "/" + for f in self.transaction.files: + if f.path == path: + details.append( + {"name": path, "size": f.size or f.tell(), "type": "file"} + ) + elif f.path.startswith(path1): + if f.path.count("/") == path1.count("/"): + details.append( + {"name": f.path, "size": f.size or f.tell(), "type": "file"} + ) + else: + dname = "/".join(f.path.split("/")[: path1.count("/") + 1]) + details.append({"name": dname, "size": 0, "type": "directory"}) + if ex is not None and not details: + raise ex + if detail: + return details + return sorted(_["name"] for _ in details) + + def info(self, path, **kwargs): + path = self._strip_protocol(path) + if self._intrans: + f = [_ for _ in self.transaction.files if _.path == path] + if f: + size = os.path.getsize(f[0].fn) if f[0].closed else f[0].tell() + return {"name": path, "size": size, "type": "file"} + f = any(_.path.startswith(path + "/") for _ in self.transaction.files) + if f: + return {"name": path, "size": 0, "type": "directory"} + return self.fs.info(path, **kwargs) + + def pipe(self, path, value=None, **kwargs): + if isinstance(path, str): + self.pipe_file(self._strip_protocol(path), value, **kwargs) + elif isinstance(path, dict): + for k, v in path.items(): + self.pipe_file(self._strip_protocol(k), v, **kwargs) + else: + raise ValueError("path must be str or dict") + + async def _cat_file(self, path, start=None, end=None, **kwargs): + logger.debug("async cat_file %s", path) + path = self._strip_protocol(path) + sha = self._mapper(path) + fn = self._check_file(path) + + if not fn: + fn = os.path.join(self.storage[-1], sha) + await self.fs._get_file(path, fn, **kwargs) + + with open(fn, "rb") as f: # noqa ASYNC230 + if start: + f.seek(start) + size = -1 if end is None else end - f.tell() + return f.read(size) + + async def _cat_ranges( + self, paths, starts, ends, max_gap=None, on_error="return", **kwargs + ): + logger.debug("async cat ranges %s", paths) + lpaths = [] + rset = set() + download = [] + rpaths = [] + for p in paths: + fn = self._check_file(p) + if fn is None and p not in rset: + sha = self._mapper(p) + fn = os.path.join(self.storage[-1], sha) + download.append(fn) + rset.add(p) + rpaths.append(p) + lpaths.append(fn) + if download: + await self.fs._get(rpaths, download, on_error=on_error) + + return LocalFileSystem().cat_ranges( + lpaths, starts, ends, max_gap=max_gap, on_error=on_error, **kwargs + ) + + def cat_ranges( + self, paths, starts, ends, max_gap=None, on_error="return", **kwargs + ): + logger.debug("cat ranges %s", paths) + lpaths = [self._check_file(p) for p in paths] + rpaths = [p for l, p in zip(lpaths, paths) if l is False] + lpaths = [l for l, p in zip(lpaths, paths) if l is False] + self.fs.get(rpaths, lpaths) + paths = [self._check_file(p) for p in paths] + return LocalFileSystem().cat_ranges( + paths, starts, ends, max_gap=max_gap, on_error=on_error, **kwargs + ) + + def _open(self, path, mode="rb", **kwargs): + path = self._strip_protocol(path) + sha = self._mapper(path) + + if "r" not in mode: + fn = os.path.join(self.storage[-1], sha) + user_specified_kwargs = { + k: v + for k, v in kwargs.items() + if k not in ["autocommit", "block_size", "cache_options"] + } # those were added by open() + return LocalTempFile( + self, + path, + mode=mode, + autocommit=not self._intrans, + fn=fn, + **user_specified_kwargs, + ) + fn = self._check_file(path) + if fn: + return open(fn, mode) + + fn = os.path.join(self.storage[-1], sha) + logger.debug("Copying %s to local cache", path) + kwargs["mode"] = mode + + self._mkcache() + self._cache_size = None + if self.compression: + with self.fs._open(path, **kwargs) as f, open(fn, "wb") as f2: + if isinstance(f, AbstractBufferedFile): + # want no type of caching if just downloading whole thing + f.cache = BaseCache(0, f.cache.fetcher, f.size) + comp = ( + infer_compression(path) + if self.compression == "infer" + else self.compression + ) + f = compr[comp](f, mode="rb") + data = True + while data: + block = getattr(f, "blocksize", 5 * 2**20) + data = f.read(block) + f2.write(data) + else: + self.fs.get_file(path, fn) + return self._open(path, mode) + + +class LocalTempFile: + """A temporary local file, which will be uploaded on commit""" + + def __init__(self, fs, path, fn, mode="wb", autocommit=True, seek=0, **kwargs): + self.fn = fn + self.fh = open(fn, mode) + self.mode = mode + if seek: + self.fh.seek(seek) + self.path = path + self.size = None + self.fs = fs + self.closed = False + self.autocommit = autocommit + self.kwargs = kwargs + + def __reduce__(self): + # always open in r+b to allow continuing writing at a location + return ( + LocalTempFile, + (self.fs, self.path, self.fn, "r+b", self.autocommit, self.tell()), + ) + + def __enter__(self): + return self.fh + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def close(self): + # self.size = self.fh.tell() + if self.closed: + return + self.fh.close() + self.closed = True + if self.autocommit: + self.commit() + + def discard(self): + self.fh.close() + os.remove(self.fn) + + def commit(self): + self.fs.put(self.fn, self.path, **self.kwargs) + # we do not delete the local copy, it's still in the cache. + + @property + def name(self): + return self.fn + + def __repr__(self) -> str: + return f"LocalTempFile: {self.path}" + + def __getattr__(self, item): + return getattr(self.fh, item) diff --git a/fsspec/implementations/dask.py b/fsspec/implementations/dask.py new file mode 100644 index 0000000000000000000000000000000000000000..3e1276463db6866665e6a0fe114efc247971b57e --- /dev/null +++ b/fsspec/implementations/dask.py @@ -0,0 +1,152 @@ +import dask +from distributed.client import Client, _get_global_client +from distributed.worker import Worker + +from fsspec import filesystem +from fsspec.spec import AbstractBufferedFile, AbstractFileSystem +from fsspec.utils import infer_storage_options + + +def _get_client(client): + if client is None: + return _get_global_client() + elif isinstance(client, Client): + return client + else: + # e.g., connection string + return Client(client) + + +def _in_worker(): + return bool(Worker._instances) + + +class DaskWorkerFileSystem(AbstractFileSystem): + """View files accessible to a worker as any other remote file-system + + When instances are run on the worker, uses the real filesystem. When + run on the client, they call the worker to provide information or data. + + **Warning** this implementation is experimental, and read-only for now. + """ + + def __init__( + self, target_protocol=None, target_options=None, fs=None, client=None, **kwargs + ): + super().__init__(**kwargs) + if not (fs is None) ^ (target_protocol is None): + raise ValueError( + "Please provide one of filesystem instance (fs) or" + " target_protocol, not both" + ) + self.target_protocol = target_protocol + self.target_options = target_options + self.worker = None + self.client = client + self.fs = fs + self._determine_worker() + + @staticmethod + def _get_kwargs_from_urls(path): + so = infer_storage_options(path) + if "host" in so and "port" in so: + return {"client": f"{so['host']}:{so['port']}"} + else: + return {} + + def _determine_worker(self): + if _in_worker(): + self.worker = True + if self.fs is None: + self.fs = filesystem( + self.target_protocol, **(self.target_options or {}) + ) + else: + self.worker = False + self.client = _get_client(self.client) + self.rfs = dask.delayed(self) + + def mkdir(self, *args, **kwargs): + if self.worker: + self.fs.mkdir(*args, **kwargs) + else: + self.rfs.mkdir(*args, **kwargs).compute() + + def rm(self, *args, **kwargs): + if self.worker: + self.fs.rm(*args, **kwargs) + else: + self.rfs.rm(*args, **kwargs).compute() + + def copy(self, *args, **kwargs): + if self.worker: + self.fs.copy(*args, **kwargs) + else: + self.rfs.copy(*args, **kwargs).compute() + + def mv(self, *args, **kwargs): + if self.worker: + self.fs.mv(*args, **kwargs) + else: + self.rfs.mv(*args, **kwargs).compute() + + def ls(self, *args, **kwargs): + if self.worker: + return self.fs.ls(*args, **kwargs) + else: + return self.rfs.ls(*args, **kwargs).compute() + + def _open( + self, + path, + mode="rb", + block_size=None, + autocommit=True, + cache_options=None, + **kwargs, + ): + if self.worker: + return self.fs._open( + path, + mode=mode, + block_size=block_size, + autocommit=autocommit, + cache_options=cache_options, + **kwargs, + ) + else: + return DaskFile( + fs=self, + path=path, + mode=mode, + block_size=block_size, + autocommit=autocommit, + cache_options=cache_options, + **kwargs, + ) + + def fetch_range(self, path, mode, start, end): + if self.worker: + with self._open(path, mode) as f: + f.seek(start) + return f.read(end - start) + else: + return self.rfs.fetch_range(path, mode, start, end).compute() + + +class DaskFile(AbstractBufferedFile): + def __init__(self, mode="rb", **kwargs): + if mode != "rb": + raise ValueError('Remote dask files can only be opened in "rb" mode') + super().__init__(**kwargs) + + def _upload_chunk(self, final=False): + pass + + def _initiate_upload(self): + """Create remote file/upload""" + pass + + def _fetch_range(self, start, end): + """Get the specified set of bytes from remote""" + return self.fs.fetch_range(self.path, self.mode, start, end) diff --git a/fsspec/implementations/data.py b/fsspec/implementations/data.py new file mode 100644 index 0000000000000000000000000000000000000000..519032305bed633f2ba8a6148076433caf81710b --- /dev/null +++ b/fsspec/implementations/data.py @@ -0,0 +1,58 @@ +import base64 +import io +from typing import Optional +from urllib.parse import unquote + +from fsspec import AbstractFileSystem + + +class DataFileSystem(AbstractFileSystem): + """A handy decoder for data-URLs + + Example + ------- + >>> with fsspec.open("data:,Hello%2C%20World%21") as f: + ... print(f.read()) + b"Hello, World!" + + See https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/Data_URLs + """ + + protocol = "data" + + def __init__(self, **kwargs): + """No parameters for this filesystem""" + super().__init__(**kwargs) + + def cat_file(self, path, start=None, end=None, **kwargs): + pref, data = path.split(",", 1) + if pref.endswith("base64"): + return base64.b64decode(data)[start:end] + return unquote(data).encode()[start:end] + + def info(self, path, **kwargs): + pref, name = path.split(",", 1) + data = self.cat_file(path) + mime = pref.split(":", 1)[1].split(";", 1)[0] + return {"name": name, "size": len(data), "type": "file", "mimetype": mime} + + def _open( + self, + path, + mode="rb", + block_size=None, + autocommit=True, + cache_options=None, + **kwargs, + ): + if "r" not in mode: + raise ValueError("Read only filesystem") + return io.BytesIO(self.cat_file(path)) + + @staticmethod + def encode(data: bytes, mime: Optional[str] = None): + """Format the given data into data-URL syntax + + This version always base64 encodes, even when the data is ascii/url-safe. + """ + return f"data:{mime or ''};base64,{base64.b64encode(data).decode()}" diff --git a/fsspec/implementations/dbfs.py b/fsspec/implementations/dbfs.py new file mode 100644 index 0000000000000000000000000000000000000000..1a7fc93d7389c894ecb5fc6267ce20abe4087068 --- /dev/null +++ b/fsspec/implementations/dbfs.py @@ -0,0 +1,496 @@ +from __future__ import annotations + +import base64 +import urllib + +import requests +from requests.adapters import HTTPAdapter, Retry +from typing_extensions import override + +from fsspec import AbstractFileSystem +from fsspec.spec import AbstractBufferedFile + + +class DatabricksException(Exception): + """ + Helper class for exceptions raised in this module. + """ + + def __init__(self, error_code, message, details=None): + """Create a new DatabricksException""" + super().__init__(message) + + self.error_code = error_code + self.message = message + self.details = details + + +class DatabricksFileSystem(AbstractFileSystem): + """ + Get access to the Databricks filesystem implementation over HTTP. + Can be used inside and outside of a databricks cluster. + """ + + def __init__(self, instance, token, **kwargs): + """ + Create a new DatabricksFileSystem. + + Parameters + ---------- + instance: str + The instance URL of the databricks cluster. + For example for an Azure databricks cluster, this + has the form adb-..azuredatabricks.net. + token: str + Your personal token. Find out more + here: https://docs.databricks.com/dev-tools/api/latest/authentication.html + """ + self.instance = instance + self.token = token + self.session = requests.Session() + self.retries = Retry( + total=10, + backoff_factor=0.05, + status_forcelist=[408, 429, 500, 502, 503, 504], + ) + + self.session.mount("https://", HTTPAdapter(max_retries=self.retries)) + self.session.headers.update({"Authorization": f"Bearer {self.token}"}) + + super().__init__(**kwargs) + + @override + def _ls_from_cache(self, path) -> list[dict[str, str | int]] | None: + """Check cache for listing + + Returns listing, if found (may be empty list for a directory that + exists but contains nothing), None if not in cache. + """ + self.dircache.pop(path.rstrip("/"), None) + + parent = self._parent(path) + if parent in self.dircache: + for entry in self.dircache[parent]: + if entry["name"] == path.rstrip("/"): + if entry["type"] != "directory": + return [entry] + return [] + raise FileNotFoundError(path) + + def ls(self, path, detail=True, **kwargs): + """ + List the contents of the given path. + + Parameters + ---------- + path: str + Absolute path + detail: bool + Return not only the list of filenames, + but also additional information on file sizes + and types. + """ + try: + out = self._ls_from_cache(path) + except FileNotFoundError: + # This happens if the `path`'s parent was cached, but `path` is not + # there. This suggests that `path` is new since the parent was + # cached. Attempt to invalidate parent's cache before continuing. + self.dircache.pop(self._parent(path), None) + out = None + + if not out: + try: + r = self._send_to_api( + method="get", endpoint="list", json={"path": path} + ) + except DatabricksException as e: + if e.error_code == "RESOURCE_DOES_NOT_EXIST": + raise FileNotFoundError(e.message) from e + + raise + files = r.get("files", []) + out = [ + { + "name": o["path"], + "type": "directory" if o["is_dir"] else "file", + "size": o["file_size"], + } + for o in files + ] + self.dircache[path] = out + + if detail: + return out + return [o["name"] for o in out] + + def makedirs(self, path, exist_ok=True): + """ + Create a given absolute path and all of its parents. + + Parameters + ---------- + path: str + Absolute path to create + exist_ok: bool + If false, checks if the folder + exists before creating it (and raises an + Exception if this is the case) + """ + if not exist_ok: + try: + # If the following succeeds, the path is already present + self._send_to_api( + method="get", endpoint="get-status", json={"path": path} + ) + raise FileExistsError(f"Path {path} already exists") + except DatabricksException as e: + if e.error_code == "RESOURCE_DOES_NOT_EXIST": + pass + + try: + self._send_to_api(method="post", endpoint="mkdirs", json={"path": path}) + except DatabricksException as e: + if e.error_code == "RESOURCE_ALREADY_EXISTS": + raise FileExistsError(e.message) from e + + raise + self.invalidate_cache(self._parent(path)) + + def mkdir(self, path, create_parents=True, **kwargs): + """ + Create a given absolute path and all of its parents. + + Parameters + ---------- + path: str + Absolute path to create + create_parents: bool + Whether to create all parents or not. + "False" is not implemented so far. + """ + if not create_parents: + raise NotImplementedError + + self.mkdirs(path, **kwargs) + + def rm(self, path, recursive=False, **kwargs): + """ + Remove the file or folder at the given absolute path. + + Parameters + ---------- + path: str + Absolute path what to remove + recursive: bool + Recursively delete all files in a folder. + """ + try: + self._send_to_api( + method="post", + endpoint="delete", + json={"path": path, "recursive": recursive}, + ) + except DatabricksException as e: + # This is not really an exception, it just means + # not everything was deleted so far + if e.error_code == "PARTIAL_DELETE": + self.rm(path=path, recursive=recursive) + elif e.error_code == "IO_ERROR": + # Using the same exception as the os module would use here + raise OSError(e.message) from e + + raise + self.invalidate_cache(self._parent(path)) + + def mv( + self, source_path, destination_path, recursive=False, maxdepth=None, **kwargs + ): + """ + Move a source to a destination path. + + A note from the original [databricks API manual] + (https://docs.databricks.com/dev-tools/api/latest/dbfs.html#move). + + When moving a large number of files the API call will time out after + approximately 60s, potentially resulting in partially moved data. + Therefore, for operations that move more than 10k files, we strongly + discourage using the DBFS REST API. + + Parameters + ---------- + source_path: str + From where to move (absolute path) + destination_path: str + To where to move (absolute path) + recursive: bool + Not implemented to far. + maxdepth: + Not implemented to far. + """ + if recursive: + raise NotImplementedError + if maxdepth: + raise NotImplementedError + + try: + self._send_to_api( + method="post", + endpoint="move", + json={"source_path": source_path, "destination_path": destination_path}, + ) + except DatabricksException as e: + if e.error_code == "RESOURCE_DOES_NOT_EXIST": + raise FileNotFoundError(e.message) from e + elif e.error_code == "RESOURCE_ALREADY_EXISTS": + raise FileExistsError(e.message) from e + + raise + self.invalidate_cache(self._parent(source_path)) + self.invalidate_cache(self._parent(destination_path)) + + def _open(self, path, mode="rb", block_size="default", **kwargs): + """ + Overwrite the base class method to make sure to create a DBFile. + All arguments are copied from the base method. + + Only the default blocksize is allowed. + """ + return DatabricksFile(self, path, mode=mode, block_size=block_size, **kwargs) + + def _send_to_api(self, method, endpoint, json): + """ + Send the given json to the DBFS API + using a get or post request (specified by the argument `method`). + + Parameters + ---------- + method: str + Which http method to use for communication; "get" or "post". + endpoint: str + Where to send the request to (last part of the API URL) + json: dict + Dictionary of information to send + """ + if method == "post": + session_call = self.session.post + elif method == "get": + session_call = self.session.get + else: + raise ValueError(f"Do not understand method {method}") + + url = urllib.parse.urljoin(f"https://{self.instance}/api/2.0/dbfs/", endpoint) + + r = session_call(url, json=json) + + # The DBFS API will return a json, also in case of an exception. + # We want to preserve this information as good as possible. + try: + r.raise_for_status() + except requests.HTTPError as e: + # try to extract json error message + # if that fails, fall back to the original exception + try: + exception_json = e.response.json() + except Exception: + raise e from None + + raise DatabricksException(**exception_json) from e + + return r.json() + + def _create_handle(self, path, overwrite=True): + """ + Internal function to create a handle, which can be used to + write blocks of a file to DBFS. + A handle has a unique identifier which needs to be passed + whenever written during this transaction. + The handle is active for 10 minutes - after that a new + write transaction needs to be created. + Make sure to close the handle after you are finished. + + Parameters + ---------- + path: str + Absolute path for this file. + overwrite: bool + If a file already exist at this location, either overwrite + it or raise an exception. + """ + try: + r = self._send_to_api( + method="post", + endpoint="create", + json={"path": path, "overwrite": overwrite}, + ) + return r["handle"] + except DatabricksException as e: + if e.error_code == "RESOURCE_ALREADY_EXISTS": + raise FileExistsError(e.message) from e + + raise + + def _close_handle(self, handle): + """ + Close a handle, which was opened by :func:`_create_handle`. + + Parameters + ---------- + handle: str + Which handle to close. + """ + try: + self._send_to_api(method="post", endpoint="close", json={"handle": handle}) + except DatabricksException as e: + if e.error_code == "RESOURCE_DOES_NOT_EXIST": + raise FileNotFoundError(e.message) from e + + raise + + def _add_data(self, handle, data): + """ + Upload data to an already opened file handle + (opened by :func:`_create_handle`). + The maximal allowed data size is 1MB after + conversion to base64. + Remember to close the handle when you are finished. + + Parameters + ---------- + handle: str + Which handle to upload data to. + data: bytes + Block of data to add to the handle. + """ + data = base64.b64encode(data).decode() + try: + self._send_to_api( + method="post", + endpoint="add-block", + json={"handle": handle, "data": data}, + ) + except DatabricksException as e: + if e.error_code == "RESOURCE_DOES_NOT_EXIST": + raise FileNotFoundError(e.message) from e + elif e.error_code == "MAX_BLOCK_SIZE_EXCEEDED": + raise ValueError(e.message) from e + + raise + + def _get_data(self, path, start, end): + """ + Download data in bytes from a given absolute path in a block + from [start, start+length]. + The maximum number of allowed bytes to read is 1MB. + + Parameters + ---------- + path: str + Absolute path to download data from + start: int + Start position of the block + end: int + End position of the block + """ + try: + r = self._send_to_api( + method="get", + endpoint="read", + json={"path": path, "offset": start, "length": end - start}, + ) + return base64.b64decode(r["data"]) + except DatabricksException as e: + if e.error_code == "RESOURCE_DOES_NOT_EXIST": + raise FileNotFoundError(e.message) from e + elif e.error_code in ["INVALID_PARAMETER_VALUE", "MAX_READ_SIZE_EXCEEDED"]: + raise ValueError(e.message) from e + + raise + + def invalidate_cache(self, path=None): + if path is None: + self.dircache.clear() + else: + self.dircache.pop(path, None) + super().invalidate_cache(path) + + +class DatabricksFile(AbstractBufferedFile): + """ + Helper class for files referenced in the DatabricksFileSystem. + """ + + DEFAULT_BLOCK_SIZE = 1 * 2**20 # only allowed block size + + def __init__( + self, + fs, + path, + mode="rb", + block_size="default", + autocommit=True, + cache_type="readahead", + cache_options=None, + **kwargs, + ): + """ + Create a new instance of the DatabricksFile. + + The blocksize needs to be the default one. + """ + if block_size is None or block_size == "default": + block_size = self.DEFAULT_BLOCK_SIZE + + assert block_size == self.DEFAULT_BLOCK_SIZE, ( + f"Only the default block size is allowed, not {block_size}" + ) + + super().__init__( + fs, + path, + mode=mode, + block_size=block_size, + autocommit=autocommit, + cache_type=cache_type, + cache_options=cache_options or {}, + **kwargs, + ) + + def _initiate_upload(self): + """Internal function to start a file upload""" + self.handle = self.fs._create_handle(self.path) + + def _upload_chunk(self, final=False): + """Internal function to add a chunk of data to a started upload""" + self.buffer.seek(0) + data = self.buffer.getvalue() + + data_chunks = [ + data[start:end] for start, end in self._to_sized_blocks(len(data)) + ] + + for data_chunk in data_chunks: + self.fs._add_data(handle=self.handle, data=data_chunk) + + if final: + self.fs._close_handle(handle=self.handle) + return True + + def _fetch_range(self, start, end): + """Internal function to download a block of data""" + return_buffer = b"" + length = end - start + for chunk_start, chunk_end in self._to_sized_blocks(length, start): + return_buffer += self.fs._get_data( + path=self.path, start=chunk_start, end=chunk_end + ) + + return return_buffer + + def _to_sized_blocks(self, length, start=0): + """Helper function to split a range from 0 to total_length into blocksizes""" + end = start + length + for data_chunk in range(start, end, self.blocksize): + data_start = data_chunk + data_end = min(end, data_chunk + self.blocksize) + yield data_start, data_end diff --git a/fsspec/implementations/dirfs.py b/fsspec/implementations/dirfs.py new file mode 100644 index 0000000000000000000000000000000000000000..c0623b82fc61ed4780fdbc1d69680f2b96f2b9f3 --- /dev/null +++ b/fsspec/implementations/dirfs.py @@ -0,0 +1,388 @@ +from .. import filesystem +from ..asyn import AsyncFileSystem + + +class DirFileSystem(AsyncFileSystem): + """Directory prefix filesystem + + The DirFileSystem is a filesystem-wrapper. It assumes every path it is dealing with + is relative to the `path`. After performing the necessary paths operation it + delegates everything to the wrapped filesystem. + """ + + protocol = "dir" + + def __init__( + self, + path=None, + fs=None, + fo=None, + target_protocol=None, + target_options=None, + **storage_options, + ): + """ + Parameters + ---------- + path: str + Path to the directory. + fs: AbstractFileSystem + An instantiated filesystem to wrap. + target_protocol, target_options: + if fs is none, construct it from these + fo: str + Alternate for path; do not provide both + """ + super().__init__(**storage_options) + if fs is None: + fs = filesystem(protocol=target_protocol, **(target_options or {})) + path = path or fo + + if self.asynchronous and not fs.async_impl: + raise ValueError("can't use asynchronous with non-async fs") + + if fs.async_impl and self.asynchronous != fs.asynchronous: + raise ValueError("both dirfs and fs should be in the same sync/async mode") + + self.path = fs._strip_protocol(path) + self.fs = fs + + def _join(self, path): + if isinstance(path, str): + if not self.path: + return path + if not path: + return self.path + return self.fs.sep.join((self.path, self._strip_protocol(path))) + if isinstance(path, dict): + return {self._join(_path): value for _path, value in path.items()} + return [self._join(_path) for _path in path] + + def _relpath(self, path): + if isinstance(path, str): + if not self.path: + return path + # We need to account for S3FileSystem returning paths that do not + # start with a '/' + if path == self.path or ( + self.path.startswith(self.fs.sep) and path == self.path[1:] + ): + return "" + prefix = self.path + self.fs.sep + if self.path.startswith(self.fs.sep) and not path.startswith(self.fs.sep): + prefix = prefix[1:] + assert path.startswith(prefix) + return path[len(prefix) :] + return [self._relpath(_path) for _path in path] + + # Wrappers below + + @property + def sep(self): + return self.fs.sep + + async def set_session(self, *args, **kwargs): + return await self.fs.set_session(*args, **kwargs) + + async def _rm_file(self, path, **kwargs): + return await self.fs._rm_file(self._join(path), **kwargs) + + def rm_file(self, path, **kwargs): + return self.fs.rm_file(self._join(path), **kwargs) + + async def _rm(self, path, *args, **kwargs): + return await self.fs._rm(self._join(path), *args, **kwargs) + + def rm(self, path, *args, **kwargs): + return self.fs.rm(self._join(path), *args, **kwargs) + + async def _cp_file(self, path1, path2, **kwargs): + return await self.fs._cp_file(self._join(path1), self._join(path2), **kwargs) + + def cp_file(self, path1, path2, **kwargs): + return self.fs.cp_file(self._join(path1), self._join(path2), **kwargs) + + async def _copy( + self, + path1, + path2, + *args, + **kwargs, + ): + return await self.fs._copy( + self._join(path1), + self._join(path2), + *args, + **kwargs, + ) + + def copy(self, path1, path2, *args, **kwargs): + return self.fs.copy( + self._join(path1), + self._join(path2), + *args, + **kwargs, + ) + + async def _pipe(self, path, *args, **kwargs): + return await self.fs._pipe(self._join(path), *args, **kwargs) + + def pipe(self, path, *args, **kwargs): + return self.fs.pipe(self._join(path), *args, **kwargs) + + async def _pipe_file(self, path, *args, **kwargs): + return await self.fs._pipe_file(self._join(path), *args, **kwargs) + + def pipe_file(self, path, *args, **kwargs): + return self.fs.pipe_file(self._join(path), *args, **kwargs) + + async def _cat_file(self, path, *args, **kwargs): + return await self.fs._cat_file(self._join(path), *args, **kwargs) + + def cat_file(self, path, *args, **kwargs): + return self.fs.cat_file(self._join(path), *args, **kwargs) + + async def _cat(self, path, *args, **kwargs): + ret = await self.fs._cat( + self._join(path), + *args, + **kwargs, + ) + + if isinstance(ret, dict): + return {self._relpath(key): value for key, value in ret.items()} + + return ret + + def cat(self, path, *args, **kwargs): + ret = self.fs.cat( + self._join(path), + *args, + **kwargs, + ) + + if isinstance(ret, dict): + return {self._relpath(key): value for key, value in ret.items()} + + return ret + + async def _put_file(self, lpath, rpath, **kwargs): + return await self.fs._put_file(lpath, self._join(rpath), **kwargs) + + def put_file(self, lpath, rpath, **kwargs): + return self.fs.put_file(lpath, self._join(rpath), **kwargs) + + async def _put( + self, + lpath, + rpath, + *args, + **kwargs, + ): + return await self.fs._put( + lpath, + self._join(rpath), + *args, + **kwargs, + ) + + def put(self, lpath, rpath, *args, **kwargs): + return self.fs.put( + lpath, + self._join(rpath), + *args, + **kwargs, + ) + + async def _get_file(self, rpath, lpath, **kwargs): + return await self.fs._get_file(self._join(rpath), lpath, **kwargs) + + def get_file(self, rpath, lpath, **kwargs): + return self.fs.get_file(self._join(rpath), lpath, **kwargs) + + async def _get(self, rpath, *args, **kwargs): + return await self.fs._get(self._join(rpath), *args, **kwargs) + + def get(self, rpath, *args, **kwargs): + return self.fs.get(self._join(rpath), *args, **kwargs) + + async def _isfile(self, path): + return await self.fs._isfile(self._join(path)) + + def isfile(self, path): + return self.fs.isfile(self._join(path)) + + async def _isdir(self, path): + return await self.fs._isdir(self._join(path)) + + def isdir(self, path): + return self.fs.isdir(self._join(path)) + + async def _size(self, path): + return await self.fs._size(self._join(path)) + + def size(self, path): + return self.fs.size(self._join(path)) + + async def _exists(self, path): + return await self.fs._exists(self._join(path)) + + def exists(self, path): + return self.fs.exists(self._join(path)) + + async def _info(self, path, **kwargs): + info = await self.fs._info(self._join(path), **kwargs) + info = info.copy() + info["name"] = self._relpath(info["name"]) + return info + + def info(self, path, **kwargs): + info = self.fs.info(self._join(path), **kwargs) + info = info.copy() + info["name"] = self._relpath(info["name"]) + return info + + async def _ls(self, path, detail=True, **kwargs): + ret = (await self.fs._ls(self._join(path), detail=detail, **kwargs)).copy() + if detail: + out = [] + for entry in ret: + entry = entry.copy() + entry["name"] = self._relpath(entry["name"]) + out.append(entry) + return out + + return self._relpath(ret) + + def ls(self, path, detail=True, **kwargs): + ret = self.fs.ls(self._join(path), detail=detail, **kwargs).copy() + if detail: + out = [] + for entry in ret: + entry = entry.copy() + entry["name"] = self._relpath(entry["name"]) + out.append(entry) + return out + + return self._relpath(ret) + + async def _walk(self, path, *args, **kwargs): + async for root, dirs, files in self.fs._walk(self._join(path), *args, **kwargs): + yield self._relpath(root), dirs, files + + def walk(self, path, *args, **kwargs): + for root, dirs, files in self.fs.walk(self._join(path), *args, **kwargs): + yield self._relpath(root), dirs, files + + async def _glob(self, path, **kwargs): + detail = kwargs.get("detail", False) + ret = await self.fs._glob(self._join(path), **kwargs) + if detail: + return {self._relpath(path): info for path, info in ret.items()} + return self._relpath(ret) + + def glob(self, path, **kwargs): + detail = kwargs.get("detail", False) + ret = self.fs.glob(self._join(path), **kwargs) + if detail: + return {self._relpath(path): info for path, info in ret.items()} + return self._relpath(ret) + + async def _du(self, path, *args, **kwargs): + total = kwargs.get("total", True) + ret = await self.fs._du(self._join(path), *args, **kwargs) + if total: + return ret + + return {self._relpath(path): size for path, size in ret.items()} + + def du(self, path, *args, **kwargs): + total = kwargs.get("total", True) + ret = self.fs.du(self._join(path), *args, **kwargs) + if total: + return ret + + return {self._relpath(path): size for path, size in ret.items()} + + async def _find(self, path, *args, **kwargs): + detail = kwargs.get("detail", False) + ret = await self.fs._find(self._join(path), *args, **kwargs) + if detail: + return {self._relpath(path): info for path, info in ret.items()} + return self._relpath(ret) + + def find(self, path, *args, **kwargs): + detail = kwargs.get("detail", False) + ret = self.fs.find(self._join(path), *args, **kwargs) + if detail: + return {self._relpath(path): info for path, info in ret.items()} + return self._relpath(ret) + + async def _expand_path(self, path, *args, **kwargs): + return self._relpath( + await self.fs._expand_path(self._join(path), *args, **kwargs) + ) + + def expand_path(self, path, *args, **kwargs): + return self._relpath(self.fs.expand_path(self._join(path), *args, **kwargs)) + + async def _mkdir(self, path, *args, **kwargs): + return await self.fs._mkdir(self._join(path), *args, **kwargs) + + def mkdir(self, path, *args, **kwargs): + return self.fs.mkdir(self._join(path), *args, **kwargs) + + async def _makedirs(self, path, *args, **kwargs): + return await self.fs._makedirs(self._join(path), *args, **kwargs) + + def makedirs(self, path, *args, **kwargs): + return self.fs.makedirs(self._join(path), *args, **kwargs) + + def rmdir(self, path): + return self.fs.rmdir(self._join(path)) + + def mv(self, path1, path2, **kwargs): + return self.fs.mv( + self._join(path1), + self._join(path2), + **kwargs, + ) + + def touch(self, path, **kwargs): + return self.fs.touch(self._join(path), **kwargs) + + def created(self, path): + return self.fs.created(self._join(path)) + + def modified(self, path): + return self.fs.modified(self._join(path)) + + def sign(self, path, *args, **kwargs): + return self.fs.sign(self._join(path), *args, **kwargs) + + def __repr__(self): + return f"{self.__class__.__qualname__}(path='{self.path}', fs={self.fs})" + + def open( + self, + path, + *args, + **kwargs, + ): + return self.fs.open( + self._join(path), + *args, + **kwargs, + ) + + async def open_async( + self, + path, + *args, + **kwargs, + ): + return await self.fs.open_async( + self._join(path), + *args, + **kwargs, + ) diff --git a/fsspec/implementations/ftp.py b/fsspec/implementations/ftp.py new file mode 100644 index 0000000000000000000000000000000000000000..a3db22b04a00ddf12582253ec19b2938c794c1da --- /dev/null +++ b/fsspec/implementations/ftp.py @@ -0,0 +1,387 @@ +import os +import uuid +from ftplib import FTP, FTP_TLS, Error, error_perm +from typing import Any + +from ..spec import AbstractBufferedFile, AbstractFileSystem +from ..utils import infer_storage_options, isfilelike + + +class FTPFileSystem(AbstractFileSystem): + """A filesystem over classic FTP""" + + root_marker = "/" + cachable = False + protocol = "ftp" + + def __init__( + self, + host, + port=21, + username=None, + password=None, + acct=None, + block_size=None, + tempdir=None, + timeout=30, + encoding="utf-8", + tls=False, + **kwargs, + ): + """ + You can use _get_kwargs_from_urls to get some kwargs from + a reasonable FTP url. + + Authentication will be anonymous if username/password are not + given. + + Parameters + ---------- + host: str + The remote server name/ip to connect to + port: int + Port to connect with + username: str or None + If authenticating, the user's identifier + password: str of None + User's password on the server, if using + acct: str or None + Some servers also need an "account" string for auth + block_size: int or None + If given, the read-ahead or write buffer size. + tempdir: str + Directory on remote to put temporary files when in a transaction + timeout: int + Timeout of the ftp connection in seconds + encoding: str + Encoding to use for directories and filenames in FTP connection + tls: bool + Use FTP-TLS, by default False + """ + super().__init__(**kwargs) + self.host = host + self.port = port + self.tempdir = tempdir or "/tmp" + self.cred = username or "", password or "", acct or "" + self.timeout = timeout + self.encoding = encoding + if block_size is not None: + self.blocksize = block_size + else: + self.blocksize = 2**16 + self.tls = tls + self._connect() + if self.tls: + self.ftp.prot_p() + + def _connect(self): + if self.tls: + ftp_cls = FTP_TLS + else: + ftp_cls = FTP + self.ftp = ftp_cls(timeout=self.timeout, encoding=self.encoding) + self.ftp.connect(self.host, self.port) + self.ftp.login(*self.cred) + + @classmethod + def _strip_protocol(cls, path): + return "/" + infer_storage_options(path)["path"].lstrip("/").rstrip("/") + + @staticmethod + def _get_kwargs_from_urls(urlpath): + out = infer_storage_options(urlpath) + out.pop("path", None) + out.pop("protocol", None) + return out + + def ls(self, path, detail=True, **kwargs): + path = self._strip_protocol(path) + out = [] + if path not in self.dircache: + try: + try: + out = [ + (fn, details) + for (fn, details) in self.ftp.mlsd(path) + if fn not in [".", ".."] + and details["type"] not in ["pdir", "cdir"] + ] + except error_perm: + out = _mlsd2(self.ftp, path) # Not platform independent + for fn, details in out: + details["name"] = "/".join( + ["" if path == "/" else path, fn.lstrip("/")] + ) + if details["type"] == "file": + details["size"] = int(details["size"]) + else: + details["size"] = 0 + if details["type"] == "dir": + details["type"] = "directory" + self.dircache[path] = out + except Error: + try: + info = self.info(path) + if info["type"] == "file": + out = [(path, info)] + except (Error, IndexError) as exc: + raise FileNotFoundError(path) from exc + files = self.dircache.get(path, out) + if not detail: + return sorted([fn for fn, details in files]) + return [details for fn, details in files] + + def info(self, path, **kwargs): + # implement with direct method + path = self._strip_protocol(path) + if path == "/": + # special case, since this dir has no real entry + return {"name": "/", "size": 0, "type": "directory"} + files = self.ls(self._parent(path).lstrip("/"), True) + try: + out = next(f for f in files if f["name"] == path) + except StopIteration as exc: + raise FileNotFoundError(path) from exc + return out + + def get_file(self, rpath, lpath, **kwargs): + if self.isdir(rpath): + if not os.path.exists(lpath): + os.mkdir(lpath) + return + if isfilelike(lpath): + outfile = lpath + else: + outfile = open(lpath, "wb") + + def cb(x): + outfile.write(x) + + self.ftp.retrbinary( + f"RETR {rpath}", + blocksize=self.blocksize, + callback=cb, + ) + if not isfilelike(lpath): + outfile.close() + + def cat_file(self, path, start=None, end=None, **kwargs): + if end is not None: + return super().cat_file(path, start, end, **kwargs) + out = [] + + def cb(x): + out.append(x) + + try: + self.ftp.retrbinary( + f"RETR {path}", + blocksize=self.blocksize, + rest=start, + callback=cb, + ) + except (Error, error_perm) as orig_exc: + raise FileNotFoundError(path) from orig_exc + return b"".join(out) + + def _open( + self, + path, + mode="rb", + block_size=None, + cache_options=None, + autocommit=True, + **kwargs, + ): + path = self._strip_protocol(path) + block_size = block_size or self.blocksize + return FTPFile( + self, + path, + mode=mode, + block_size=block_size, + tempdir=self.tempdir, + autocommit=autocommit, + cache_options=cache_options, + ) + + def _rm(self, path): + path = self._strip_protocol(path) + self.ftp.delete(path) + self.invalidate_cache(self._parent(path)) + + def rm(self, path, recursive=False, maxdepth=None): + paths = self.expand_path(path, recursive=recursive, maxdepth=maxdepth) + for p in reversed(paths): + if self.isfile(p): + self.rm_file(p) + else: + self.rmdir(p) + + def mkdir(self, path: str, create_parents: bool = True, **kwargs: Any) -> None: + path = self._strip_protocol(path) + parent = self._parent(path) + if parent != self.root_marker and not self.exists(parent) and create_parents: + self.mkdir(parent, create_parents=create_parents) + + self.ftp.mkd(path) + self.invalidate_cache(self._parent(path)) + + def makedirs(self, path: str, exist_ok: bool = False) -> None: + path = self._strip_protocol(path) + if self.exists(path): + # NB: "/" does not "exist" as it has no directory entry + if not exist_ok: + raise FileExistsError(f"{path} exists without `exist_ok`") + # exists_ok=True -> no-op + else: + self.mkdir(path, create_parents=True) + + def rmdir(self, path): + path = self._strip_protocol(path) + self.ftp.rmd(path) + self.invalidate_cache(self._parent(path)) + + def mv(self, path1, path2, **kwargs): + path1 = self._strip_protocol(path1) + path2 = self._strip_protocol(path2) + self.ftp.rename(path1, path2) + self.invalidate_cache(self._parent(path1)) + self.invalidate_cache(self._parent(path2)) + + def __del__(self): + self.ftp.close() + + def invalidate_cache(self, path=None): + if path is None: + self.dircache.clear() + else: + self.dircache.pop(path, None) + super().invalidate_cache(path) + + +class TransferDone(Exception): + """Internal exception to break out of transfer""" + + pass + + +class FTPFile(AbstractBufferedFile): + """Interact with a remote FTP file with read/write buffering""" + + def __init__( + self, + fs, + path, + mode="rb", + block_size="default", + autocommit=True, + cache_type="readahead", + cache_options=None, + **kwargs, + ): + super().__init__( + fs, + path, + mode=mode, + block_size=block_size, + autocommit=autocommit, + cache_type=cache_type, + cache_options=cache_options, + **kwargs, + ) + if not autocommit: + self.target = self.path + self.path = "/".join([kwargs["tempdir"], str(uuid.uuid4())]) + + def commit(self): + self.fs.mv(self.path, self.target) + + def discard(self): + self.fs.rm(self.path) + + def _fetch_range(self, start, end): + """Get bytes between given byte limits + + Implemented by raising an exception in the fetch callback when the + number of bytes received reaches the requested amount. + + Will fail if the server does not respect the REST command on + retrieve requests. + """ + out = [] + total = [0] + + def callback(x): + total[0] += len(x) + if total[0] > end - start: + out.append(x[: (end - start) - total[0]]) + if end < self.size: + raise TransferDone + else: + out.append(x) + + if total[0] == end - start and end < self.size: + raise TransferDone + + try: + self.fs.ftp.retrbinary( + f"RETR {self.path}", + blocksize=self.blocksize, + rest=start, + callback=callback, + ) + except TransferDone: + try: + # stop transfer, we got enough bytes for this block + self.fs.ftp.abort() + self.fs.ftp.getmultiline() + except Error: + self.fs._connect() + + return b"".join(out) + + def _upload_chunk(self, final=False): + self.buffer.seek(0) + self.fs.ftp.storbinary( + f"STOR {self.path}", self.buffer, blocksize=self.blocksize, rest=self.offset + ) + return True + + +def _mlsd2(ftp, path="."): + """ + Fall back to using `dir` instead of `mlsd` if not supported. + + This parses a Linux style `ls -l` response to `dir`, but the response may + be platform dependent. + + Parameters + ---------- + ftp: ftplib.FTP + path: str + Expects to be given path, but defaults to ".". + """ + lines = [] + minfo = [] + ftp.dir(path, lines.append) + for line in lines: + split_line = line.split() + if len(split_line) < 9: + continue + this = ( + split_line[-1], + { + "modify": " ".join(split_line[5:8]), + "unix.owner": split_line[2], + "unix.group": split_line[3], + "unix.mode": split_line[0], + "size": split_line[4], + }, + ) + if this[1]["unix.mode"][0] == "d": + this[1]["type"] = "dir" + else: + this[1]["type"] = "file" + minfo.append(this) + return minfo diff --git a/fsspec/implementations/gist.py b/fsspec/implementations/gist.py new file mode 100644 index 0000000000000000000000000000000000000000..74117b544b02108fd2a7be06c716df477481ed47 --- /dev/null +++ b/fsspec/implementations/gist.py @@ -0,0 +1,232 @@ +import requests + +from ..spec import AbstractFileSystem +from ..utils import infer_storage_options +from .memory import MemoryFile + + +class GistFileSystem(AbstractFileSystem): + """ + Interface to files in a single GitHub Gist. + + Provides read-only access to a gist's files. Gists do not contain + subdirectories, so file listing is straightforward. + + Parameters + ---------- + gist_id : str + The ID of the gist you want to access (the long hex value from the URL). + filenames : list[str] (optional) + If provided, only make a file system representing these files, and do not fetch + the list of all files for this gist. + sha : str (optional) + If provided, fetch a particular revision of the gist. If omitted, + the latest revision is used. + username : str (optional) + GitHub username for authentication (required if token is given). + token : str (optional) + GitHub personal access token (required if username is given). + timeout : (float, float) or float, optional + Connect and read timeouts for requests (default 60s each). + kwargs : dict + Stored on `self.request_kw` and passed to `requests.get` when fetching Gist + metadata or reading ("opening") a file. + """ + + protocol = "gist" + gist_url = "https://api.github.com/gists/{gist_id}" + gist_rev_url = "https://api.github.com/gists/{gist_id}/{sha}" + + def __init__( + self, + gist_id, + filenames=None, + sha=None, + username=None, + token=None, + timeout=None, + **kwargs, + ): + super().__init__() + self.gist_id = gist_id + self.filenames = filenames + self.sha = sha # revision of the gist (optional) + if (username is None) ^ (token is None): + # Both or neither must be set + if username or token: + raise ValueError("Auth requires both username and token, or neither.") + self.username = username + self.token = token + self.request_kw = kwargs + # Default timeouts to 60s connect/read if none provided + self.timeout = timeout if timeout is not None else (60, 60) + + # We use a single-level "directory" cache, because a gist is essentially flat + self.dircache[""] = self._fetch_file_list() + + @property + def kw(self): + """Auth parameters passed to 'requests' if we have username/token.""" + if self.username is not None and self.token is not None: + return {"auth": (self.username, self.token), **self.request_kw} + return self.request_kw + + def _fetch_gist_metadata(self): + """ + Fetch the JSON metadata for this gist (possibly for a specific revision). + """ + if self.sha: + url = self.gist_rev_url.format(gist_id=self.gist_id, sha=self.sha) + else: + url = self.gist_url.format(gist_id=self.gist_id) + + r = requests.get(url, timeout=self.timeout, **self.kw) + if r.status_code == 404: + raise FileNotFoundError( + f"Gist not found: {self.gist_id}@{self.sha or 'latest'}" + ) + r.raise_for_status() + return r.json() + + def _fetch_file_list(self): + """ + Returns a list of dicts describing each file in the gist. These get stored + in self.dircache[""]. + """ + meta = self._fetch_gist_metadata() + if self.filenames: + available_files = meta.get("files", {}) + files = {} + for fn in self.filenames: + if fn not in available_files: + raise FileNotFoundError(fn) + files[fn] = available_files[fn] + else: + files = meta.get("files", {}) + + out = [] + for fname, finfo in files.items(): + if finfo is None: + # Occasionally GitHub returns a file entry with null if it was deleted + continue + # Build a directory entry + out.append( + { + "name": fname, # file's name + "type": "file", # gists have no subdirectories + "size": finfo.get("size", 0), # file size in bytes + "raw_url": finfo.get("raw_url"), + } + ) + return out + + @classmethod + def _strip_protocol(cls, path): + """ + Remove 'gist://' from the path, if present. + """ + # The default infer_storage_options can handle gist://username:token@id/file + # or gist://id/file, but let's ensure we handle a normal usage too. + # We'll just strip the protocol prefix if it exists. + path = infer_storage_options(path).get("path", path) + return path.lstrip("/") + + @staticmethod + def _get_kwargs_from_urls(path): + """ + Parse 'gist://' style URLs into GistFileSystem constructor kwargs. + For example: + gist://:TOKEN@/file.txt + gist://username:TOKEN@/file.txt + """ + so = infer_storage_options(path) + out = {} + if "username" in so and so["username"]: + out["username"] = so["username"] + if "password" in so and so["password"]: + out["token"] = so["password"] + if "host" in so and so["host"]: + # We interpret 'host' as the gist ID + out["gist_id"] = so["host"] + + # Extract SHA and filename from path + if "path" in so and so["path"]: + path_parts = so["path"].rsplit("/", 2)[-2:] + if len(path_parts) == 2: + if path_parts[0]: # SHA present + out["sha"] = path_parts[0] + if path_parts[1]: # filename also present + out["filenames"] = [path_parts[1]] + + return out + + def ls(self, path="", detail=False, **kwargs): + """ + List files in the gist. Gists are single-level, so any 'path' is basically + the filename, or empty for all files. + + Parameters + ---------- + path : str, optional + The filename to list. If empty, returns all files in the gist. + detail : bool, default False + If True, return a list of dicts; if False, return a list of filenames. + """ + path = self._strip_protocol(path or "") + # If path is empty, return all + if path == "": + results = self.dircache[""] + else: + # We want just the single file with this name + all_files = self.dircache[""] + results = [f for f in all_files if f["name"] == path] + if not results: + raise FileNotFoundError(path) + if detail: + return results + else: + return sorted(f["name"] for f in results) + + def _open(self, path, mode="rb", block_size=None, **kwargs): + """ + Read a single file from the gist. + """ + if mode != "rb": + raise NotImplementedError("GitHub Gist FS is read-only (no write).") + + path = self._strip_protocol(path) + # Find the file entry in our dircache + matches = [f for f in self.dircache[""] if f["name"] == path] + if not matches: + raise FileNotFoundError(path) + finfo = matches[0] + + raw_url = finfo.get("raw_url") + if not raw_url: + raise FileNotFoundError(f"No raw_url for file: {path}") + + r = requests.get(raw_url, timeout=self.timeout, **self.kw) + if r.status_code == 404: + raise FileNotFoundError(path) + r.raise_for_status() + return MemoryFile(path, None, r.content) + + def cat(self, path, recursive=False, on_error="raise", **kwargs): + """ + Return {path: contents} for the given file or files. If 'recursive' is True, + and path is empty, returns all files in the gist. + """ + paths = self.expand_path(path, recursive=recursive) + out = {} + for p in paths: + try: + with self.open(p, "rb") as f: + out[p] = f.read() + except FileNotFoundError as e: + if on_error == "raise": + raise e + elif on_error == "omit": + pass # skip + else: + out[p] = e + return out diff --git a/fsspec/implementations/git.py b/fsspec/implementations/git.py new file mode 100644 index 0000000000000000000000000000000000000000..808d293a1c991ea87d19a2129f3e56d9b813daaa --- /dev/null +++ b/fsspec/implementations/git.py @@ -0,0 +1,114 @@ +import os + +import pygit2 + +from fsspec.spec import AbstractFileSystem + +from .memory import MemoryFile + + +class GitFileSystem(AbstractFileSystem): + """Browse the files of a local git repo at any hash/tag/branch + + (experimental backend) + """ + + root_marker = "" + cachable = True + + def __init__(self, path=None, fo=None, ref=None, **kwargs): + """ + + Parameters + ---------- + path: str (optional) + Local location of the repo (uses current directory if not given). + May be deprecated in favour of ``fo``. When used with a higher + level function such as fsspec.open(), may be of the form + "git://[path-to-repo[:]][ref@]path/to/file" (but the actual + file path should not contain "@" or ":"). + fo: str (optional) + Same as ``path``, but passed as part of a chained URL. This one + takes precedence if both are given. + ref: str (optional) + Reference to work with, could be a hash, tag or branch name. Defaults + to current working tree. Note that ``ls`` and ``open`` also take hash, + so this becomes the default for those operations + kwargs + """ + super().__init__(**kwargs) + self.repo = pygit2.Repository(fo or path or os.getcwd()) + self.ref = ref or "master" + + @classmethod + def _strip_protocol(cls, path): + path = super()._strip_protocol(path).lstrip("/") + if ":" in path: + path = path.split(":", 1)[1] + if "@" in path: + path = path.split("@", 1)[1] + return path.lstrip("/") + + def _path_to_object(self, path, ref): + comm, ref = self.repo.resolve_refish(ref or self.ref) + parts = path.split("/") + tree = comm.tree + for part in parts: + if part and isinstance(tree, pygit2.Tree): + if part not in tree: + raise FileNotFoundError(path) + tree = tree[part] + return tree + + @staticmethod + def _get_kwargs_from_urls(path): + path = path.removeprefix("git://") + out = {} + if ":" in path: + out["path"], path = path.split(":", 1) + if "@" in path: + out["ref"], path = path.split("@", 1) + return out + + @staticmethod + def _object_to_info(obj, path=None): + # obj.name and obj.filemode are None for the root tree! + is_dir = isinstance(obj, pygit2.Tree) + return { + "type": "directory" if is_dir else "file", + "name": ( + "/".join([path, obj.name or ""]).lstrip("/") if path else obj.name + ), + "hex": str(obj.id), + "mode": "100644" if obj.filemode is None else f"{obj.filemode:o}", + "size": 0 if is_dir else obj.size, + } + + def ls(self, path, detail=True, ref=None, **kwargs): + tree = self._path_to_object(self._strip_protocol(path), ref) + return [ + GitFileSystem._object_to_info(obj, path) + if detail + else GitFileSystem._object_to_info(obj, path)["name"] + for obj in (tree if isinstance(tree, pygit2.Tree) else [tree]) + ] + + def info(self, path, ref=None, **kwargs): + tree = self._path_to_object(self._strip_protocol(path), ref) + return GitFileSystem._object_to_info(tree, path) + + def ukey(self, path, ref=None): + return self.info(path, ref=ref)["hex"] + + def _open( + self, + path, + mode="rb", + block_size=None, + autocommit=True, + cache_options=None, + ref=None, + **kwargs, + ): + obj = self._path_to_object(path, ref or self.ref) + return MemoryFile(data=obj.data) diff --git a/fsspec/implementations/github.py b/fsspec/implementations/github.py new file mode 100644 index 0000000000000000000000000000000000000000..3630f6db54413e2c396f6cc1b6b10cd379200043 --- /dev/null +++ b/fsspec/implementations/github.py @@ -0,0 +1,333 @@ +import base64 +import re + +import requests + +from ..spec import AbstractFileSystem +from ..utils import infer_storage_options +from .memory import MemoryFile + + +class GithubFileSystem(AbstractFileSystem): + """Interface to files in github + + An instance of this class provides the files residing within a remote github + repository. You may specify a point in the repos history, by SHA, branch + or tag (default is current master). + + For files less than 1 MB in size, file content is returned directly in a + MemoryFile. For larger files, or for files tracked by git-lfs, file content + is returned as an HTTPFile wrapping the ``download_url`` provided by the + GitHub API. + + When using fsspec.open, allows URIs of the form: + + - "github://path/file", in which case you must specify org, repo and + may specify sha in the extra args + - 'github://org:repo@/precip/catalog.yml', where the org and repo are + part of the URI + - 'github://org:repo@sha/precip/catalog.yml', where the sha is also included + + ``sha`` can be the full or abbreviated hex of the commit you want to fetch + from, or a branch or tag name (so long as it doesn't contain special characters + like "/", "?", which would have to be HTTP-encoded). + + For authorised access, you must provide username and token, which can be made + at https://github.com/settings/tokens + """ + + url = "https://api.github.com/repos/{org}/{repo}/git/trees/{sha}" + content_url = "https://api.github.com/repos/{org}/{repo}/contents/{path}?ref={sha}" + protocol = "github" + timeout = (60, 60) # connect, read timeouts + + def __init__( + self, org, repo, sha=None, username=None, token=None, timeout=None, **kwargs + ): + super().__init__(**kwargs) + self.org = org + self.repo = repo + if (username is None) ^ (token is None): + raise ValueError("Auth required both username and token") + self.username = username + self.token = token + if timeout is not None: + self.timeout = timeout + if sha is None: + # look up default branch (not necessarily "master") + u = "https://api.github.com/repos/{org}/{repo}" + r = requests.get( + u.format(org=org, repo=repo), timeout=self.timeout, **self.kw + ) + r.raise_for_status() + sha = r.json()["default_branch"] + + self.root = sha + self.ls("") + try: + from .http import HTTPFileSystem + + self.http_fs = HTTPFileSystem(**kwargs) + except ImportError: + self.http_fs = None + + @property + def kw(self): + if self.username: + return {"auth": (self.username, self.token)} + return {} + + @classmethod + def repos(cls, org_or_user, is_org=True): + """List repo names for given org or user + + This may become the top level of the FS + + Parameters + ---------- + org_or_user: str + Name of the github org or user to query + is_org: bool (default True) + Whether the name is an organisation (True) or user (False) + + Returns + ------- + List of string + """ + r = requests.get( + f"https://api.github.com/{['users', 'orgs'][is_org]}/{org_or_user}/repos", + timeout=cls.timeout, + ) + r.raise_for_status() + return [repo["name"] for repo in r.json()] + + @property + def tags(self): + """Names of tags in the repo""" + r = requests.get( + f"https://api.github.com/repos/{self.org}/{self.repo}/tags", + timeout=self.timeout, + **self.kw, + ) + r.raise_for_status() + return [t["name"] for t in r.json()] + + @property + def branches(self): + """Names of branches in the repo""" + r = requests.get( + f"https://api.github.com/repos/{self.org}/{self.repo}/branches", + timeout=self.timeout, + **self.kw, + ) + r.raise_for_status() + return [t["name"] for t in r.json()] + + @property + def refs(self): + """Named references, tags and branches""" + return {"tags": self.tags, "branches": self.branches} + + def ls(self, path, detail=False, sha=None, _sha=None, **kwargs): + """List files at given path + + Parameters + ---------- + path: str + Location to list, relative to repo root + detail: bool + If True, returns list of dicts, one per file; if False, returns + list of full filenames only + sha: str (optional) + List at the given point in the repo history, branch or tag name or commit + SHA + _sha: str (optional) + List this specific tree object (used internally to descend into trees) + """ + path = self._strip_protocol(path) + if path == "": + _sha = sha or self.root + if _sha is None: + parts = path.rstrip("/").split("/") + so_far = "" + _sha = sha or self.root + for part in parts: + out = self.ls(so_far, True, sha=sha, _sha=_sha) + so_far += "/" + part if so_far else part + out = [o for o in out if o["name"] == so_far] + if not out: + raise FileNotFoundError(path) + out = out[0] + if out["type"] == "file": + if detail: + return [out] + else: + return path + _sha = out["sha"] + if path not in self.dircache or sha not in [self.root, None]: + r = requests.get( + self.url.format(org=self.org, repo=self.repo, sha=_sha), + timeout=self.timeout, + **self.kw, + ) + if r.status_code == 404: + raise FileNotFoundError(path) + r.raise_for_status() + types = {"blob": "file", "tree": "directory"} + out = [ + { + "name": path + "/" + f["path"] if path else f["path"], + "mode": f["mode"], + "type": types[f["type"]], + "size": f.get("size", 0), + "sha": f["sha"], + } + for f in r.json()["tree"] + if f["type"] in types + ] + if sha in [self.root, None]: + self.dircache[path] = out + else: + out = self.dircache[path] + if detail: + return out + else: + return sorted([f["name"] for f in out]) + + def invalidate_cache(self, path=None): + self.dircache.clear() + + @classmethod + def _strip_protocol(cls, path): + opts = infer_storage_options(path) + if "username" not in opts: + return super()._strip_protocol(path) + return opts["path"].lstrip("/") + + @staticmethod + def _get_kwargs_from_urls(path): + opts = infer_storage_options(path) + if "username" not in opts: + return {} + out = {"org": opts["username"], "repo": opts["password"]} + if opts["host"]: + out["sha"] = opts["host"] + return out + + def _open( + self, + path, + mode="rb", + block_size=None, + cache_options=None, + sha=None, + **kwargs, + ): + if mode != "rb": + raise NotImplementedError + + # construct a url to hit the GitHub API's repo contents API + url = self.content_url.format( + org=self.org, repo=self.repo, path=path, sha=sha or self.root + ) + + # make a request to this API, and parse the response as JSON + r = requests.get(url, timeout=self.timeout, **self.kw) + if r.status_code == 404: + raise FileNotFoundError(path) + r.raise_for_status() + content_json = r.json() + + # if the response's content key is not empty, try to parse it as base64 + if content_json["content"]: + content = base64.b64decode(content_json["content"]) + + # as long as the content does not start with the string + # "version https://git-lfs.github.com/" + # then it is probably not a git-lfs pointer and we can just return + # the content directly + if not content.startswith(b"version https://git-lfs.github.com/"): + return MemoryFile(None, None, content) + + # we land here if the content was not present in the first response + # (regular file over 1MB or git-lfs tracked file) + # in this case, we get let the HTTPFileSystem handle the download + if self.http_fs is None: + raise ImportError( + "Please install fsspec[http] to access github files >1 MB " + "or git-lfs tracked files." + ) + return self.http_fs.open( + content_json["download_url"], + mode=mode, + block_size=block_size, + cache_options=cache_options, + **kwargs, + ) + + def rm(self, path, recursive=False, maxdepth=None, message=None): + path = self.expand_path(path, recursive=recursive, maxdepth=maxdepth) + for p in reversed(path): + self.rm_file(p, message=message) + + def rm_file(self, path, message=None, **kwargs): + """ + Remove a file from a specified branch using a given commit message. + + Since Github DELETE operation requires a branch name, and we can't reliably + determine whether the provided SHA refers to a branch, tag, or commit, we + assume it's a branch. If it's not, the user will encounter an error when + attempting to retrieve the file SHA or delete the file. + + Parameters + ---------- + path: str + The file's location relative to the repository root. + message: str, optional + The commit message for the deletion. + """ + + if not self.username: + raise ValueError("Authentication required") + + path = self._strip_protocol(path) + + # Attempt to get SHA from cache or Github API + sha = self._get_sha_from_cache(path) + if not sha: + url = self.content_url.format( + org=self.org, repo=self.repo, path=path.lstrip("/"), sha=self.root + ) + r = requests.get(url, timeout=self.timeout, **self.kw) + if r.status_code == 404: + raise FileNotFoundError(path) + r.raise_for_status() + sha = r.json()["sha"] + + # Delete the file + delete_url = self.content_url.format( + org=self.org, repo=self.repo, path=path, sha=self.root + ) + branch = self.root + data = { + "message": message or f"Delete {path}", + "sha": sha, + **({"branch": branch} if branch else {}), + } + + r = requests.delete(delete_url, json=data, timeout=self.timeout, **self.kw) + error_message = r.json().get("message", "") + if re.search(r"Branch .+ not found", error_message): + error = "Remove only works when the filesystem is initialised from a branch or default (None)" + raise ValueError(error) + r.raise_for_status() + + self.invalidate_cache(path) + + def _get_sha_from_cache(self, path): + for entries in self.dircache.values(): + for entry in entries: + entry_path = entry.get("name") + if entry_path and entry_path == path and "sha" in entry: + return entry["sha"] + return None diff --git a/fsspec/implementations/http.py b/fsspec/implementations/http.py new file mode 100644 index 0000000000000000000000000000000000000000..093fa29bea8f0296bf003a4fa1f25b99fc219831 --- /dev/null +++ b/fsspec/implementations/http.py @@ -0,0 +1,890 @@ +import asyncio +import io +import logging +import re +import weakref +from copy import copy +from urllib.parse import urlparse + +import aiohttp +import yarl + +from fsspec.asyn import AbstractAsyncStreamedFile, AsyncFileSystem, sync, sync_wrapper +from fsspec.callbacks import DEFAULT_CALLBACK +from fsspec.exceptions import FSTimeoutError +from fsspec.spec import AbstractBufferedFile +from fsspec.utils import ( + DEFAULT_BLOCK_SIZE, + glob_translate, + isfilelike, + nullcontext, + tokenize, +) + +from ..caching import AllBytes + +# https://stackoverflow.com/a/15926317/3821154 +ex = re.compile(r"""<(a|A)\s+(?:[^>]*?\s+)?(href|HREF)=["'](?P[^"']+)""") +ex2 = re.compile(r"""(?Phttp[s]?://[-a-zA-Z0-9@:%_+.~#?&/=]+)""") +logger = logging.getLogger("fsspec.http") + + +async def get_client(**kwargs): + return aiohttp.ClientSession(**kwargs) + + +class HTTPFileSystem(AsyncFileSystem): + """ + Simple File-System for fetching data via HTTP(S) + + ``ls()`` is implemented by loading the parent page and doing a regex + match on the result. If simple_link=True, anything of the form + "http(s)://server.com/stuff?thing=other"; otherwise only links within + HTML href tags will be used. + """ + + sep = "/" + + def __init__( + self, + simple_links=True, + block_size=None, + same_scheme=True, + size_policy=None, + cache_type="bytes", + cache_options=None, + asynchronous=False, + loop=None, + client_kwargs=None, + get_client=get_client, + encoded=False, + **storage_options, + ): + """ + NB: if this is called async, you must await set_client + + Parameters + ---------- + block_size: int + Blocks to read bytes; if 0, will default to raw requests file-like + objects instead of HTTPFile instances + simple_links: bool + If True, will consider both HTML tags and anything that looks + like a URL; if False, will consider only the former. + same_scheme: True + When doing ls/glob, if this is True, only consider paths that have + http/https matching the input URLs. + size_policy: this argument is deprecated + client_kwargs: dict + Passed to aiohttp.ClientSession, see + https://docs.aiohttp.org/en/stable/client_reference.html + For example, ``{'auth': aiohttp.BasicAuth('user', 'pass')}`` + get_client: Callable[..., aiohttp.ClientSession] + A callable, which takes keyword arguments and constructs + an aiohttp.ClientSession. Its state will be managed by + the HTTPFileSystem class. + storage_options: key-value + Any other parameters passed on to requests + cache_type, cache_options: defaults used in open() + """ + super().__init__(self, asynchronous=asynchronous, loop=loop, **storage_options) + self.block_size = block_size if block_size is not None else DEFAULT_BLOCK_SIZE + self.simple_links = simple_links + self.same_schema = same_scheme + self.cache_type = cache_type + self.cache_options = cache_options + self.client_kwargs = client_kwargs or {} + self.get_client = get_client + self.encoded = encoded + self.kwargs = storage_options + self._session = None + + # Clean caching-related parameters from `storage_options` + # before propagating them as `request_options` through `self.kwargs`. + # TODO: Maybe rename `self.kwargs` to `self.request_options` to make + # it clearer. + request_options = copy(storage_options) + self.use_listings_cache = request_options.pop("use_listings_cache", False) + request_options.pop("listings_expiry_time", None) + request_options.pop("max_paths", None) + request_options.pop("skip_instance_cache", None) + self.kwargs = request_options + + @property + def fsid(self): + return "http" + + def encode_url(self, url): + return yarl.URL(url, encoded=self.encoded) + + @staticmethod + def close_session(loop, session): + if loop is not None and loop.is_running(): + try: + sync(loop, session.close, timeout=0.1) + return + except (TimeoutError, FSTimeoutError, NotImplementedError): + pass + connector = getattr(session, "_connector", None) + if connector is not None: + # close after loop is dead + connector._close() + + async def set_session(self): + if self._session is None: + self._session = await self.get_client(loop=self.loop, **self.client_kwargs) + if not self.asynchronous: + weakref.finalize(self, self.close_session, self.loop, self._session) + return self._session + + @classmethod + def _strip_protocol(cls, path): + """For HTTP, we always want to keep the full URL""" + return path + + @classmethod + def _parent(cls, path): + # override, since _strip_protocol is different for URLs + par = super()._parent(path) + if len(par) > 7: # "http://..." + return par + return "" + + async def _ls_real(self, url, detail=True, **kwargs): + # ignoring URL-encoded arguments + kw = self.kwargs.copy() + kw.update(kwargs) + logger.debug(url) + session = await self.set_session() + async with session.get(self.encode_url(url), **self.kwargs) as r: + self._raise_not_found_for_status(r, url) + + if "Content-Type" in r.headers: + mimetype = r.headers["Content-Type"].partition(";")[0] + else: + mimetype = None + + if mimetype in ("text/html", None): + try: + text = await r.text(errors="ignore") + if self.simple_links: + links = ex2.findall(text) + [u[2] for u in ex.findall(text)] + else: + links = [u[2] for u in ex.findall(text)] + except UnicodeDecodeError: + links = [] # binary, not HTML + else: + links = [] + + out = set() + parts = urlparse(url) + for l in links: + if isinstance(l, tuple): + l = l[1] + if l.startswith("/") and len(l) > 1: + # absolute URL on this server + l = f"{parts.scheme}://{parts.netloc}{l}" + if l.startswith("http"): + if self.same_schema and l.startswith(url.rstrip("/") + "/"): + out.add(l) + elif l.replace("https", "http").startswith( + url.replace("https", "http").rstrip("/") + "/" + ): + # allowed to cross http <-> https + out.add(l) + else: + if l not in ["..", "../"]: + # Ignore FTP-like "parent" + out.add("/".join([url.rstrip("/"), l.lstrip("/")])) + if not out and url.endswith("/"): + out = await self._ls_real(url.rstrip("/"), detail=False) + if detail: + return [ + { + "name": u, + "size": None, + "type": "directory" if u.endswith("/") else "file", + } + for u in out + ] + else: + return sorted(out) + + async def _ls(self, url, detail=True, **kwargs): + if self.use_listings_cache and url in self.dircache: + out = self.dircache[url] + else: + out = await self._ls_real(url, detail=detail, **kwargs) + self.dircache[url] = out + return out + + ls = sync_wrapper(_ls) + + def _raise_not_found_for_status(self, response, url): + """ + Raises FileNotFoundError for 404s, otherwise uses raise_for_status. + """ + if response.status == 404: + raise FileNotFoundError(url) + response.raise_for_status() + + async def _cat_file(self, url, start=None, end=None, **kwargs): + kw = self.kwargs.copy() + kw.update(kwargs) + logger.debug(url) + + if start is not None or end is not None: + if start == end: + return b"" + headers = kw.pop("headers", {}).copy() + + headers["Range"] = await self._process_limits(url, start, end) + kw["headers"] = headers + session = await self.set_session() + async with session.get(self.encode_url(url), **kw) as r: + out = await r.read() + self._raise_not_found_for_status(r, url) + return out + + async def _get_file( + self, rpath, lpath, chunk_size=5 * 2**20, callback=DEFAULT_CALLBACK, **kwargs + ): + kw = self.kwargs.copy() + kw.update(kwargs) + logger.debug(rpath) + session = await self.set_session() + async with session.get(self.encode_url(rpath), **kw) as r: + try: + size = int(r.headers["content-length"]) + except (ValueError, KeyError): + size = None + + callback.set_size(size) + self._raise_not_found_for_status(r, rpath) + if isfilelike(lpath): + outfile = lpath + else: + outfile = open(lpath, "wb") # noqa: ASYNC230 + + try: + chunk = True + while chunk: + chunk = await r.content.read(chunk_size) + outfile.write(chunk) + callback.relative_update(len(chunk)) + finally: + if not isfilelike(lpath): + outfile.close() + + async def _put_file( + self, + lpath, + rpath, + chunk_size=5 * 2**20, + callback=DEFAULT_CALLBACK, + method="post", + mode="overwrite", + **kwargs, + ): + if mode != "overwrite": + raise NotImplementedError("Exclusive write") + + async def gen_chunks(): + # Support passing arbitrary file-like objects + # and use them instead of streams. + if isinstance(lpath, io.IOBase): + context = nullcontext(lpath) + use_seek = False # might not support seeking + else: + context = open(lpath, "rb") # noqa: ASYNC230 + use_seek = True + + with context as f: + if use_seek: + callback.set_size(f.seek(0, 2)) + f.seek(0) + else: + callback.set_size(getattr(f, "size", None)) + + chunk = f.read(chunk_size) + while chunk: + yield chunk + callback.relative_update(len(chunk)) + chunk = f.read(chunk_size) + + kw = self.kwargs.copy() + kw.update(kwargs) + session = await self.set_session() + + method = method.lower() + if method not in ("post", "put"): + raise ValueError( + f"method has to be either 'post' or 'put', not: {method!r}" + ) + + meth = getattr(session, method) + async with meth(self.encode_url(rpath), data=gen_chunks(), **kw) as resp: + self._raise_not_found_for_status(resp, rpath) + + async def _exists(self, path, **kwargs): + kw = self.kwargs.copy() + kw.update(kwargs) + try: + logger.debug(path) + session = await self.set_session() + r = await session.get(self.encode_url(path), **kw) + async with r: + return r.status < 400 + except aiohttp.ClientError: + return False + + async def _isfile(self, path, **kwargs): + return await self._exists(path, **kwargs) + + def _open( + self, + path, + mode="rb", + block_size=None, + autocommit=None, # XXX: This differs from the base class. + cache_type=None, + cache_options=None, + size=None, + **kwargs, + ): + """Make a file-like object + + Parameters + ---------- + path: str + Full URL with protocol + mode: string + must be "rb" + block_size: int or None + Bytes to download in one request; use instance value if None. If + zero, will return a streaming Requests file-like instance. + kwargs: key-value + Any other parameters, passed to requests calls + """ + if mode != "rb": + raise NotImplementedError + block_size = block_size if block_size is not None else self.block_size + kw = self.kwargs.copy() + kw["asynchronous"] = self.asynchronous + kw.update(kwargs) + info = {} + size = size or info.update(self.info(path, **kwargs)) or info["size"] + session = sync(self.loop, self.set_session) + if block_size and size and info.get("partial", True): + return HTTPFile( + self, + path, + session=session, + block_size=block_size, + mode=mode, + size=size, + cache_type=cache_type or self.cache_type, + cache_options=cache_options or self.cache_options, + loop=self.loop, + **kw, + ) + else: + return HTTPStreamFile( + self, + path, + mode=mode, + loop=self.loop, + session=session, + **kw, + ) + + async def open_async(self, path, mode="rb", size=None, **kwargs): + session = await self.set_session() + if size is None: + try: + size = (await self._info(path, **kwargs))["size"] + except FileNotFoundError: + pass + return AsyncStreamFile( + self, + path, + loop=self.loop, + session=session, + size=size, + **kwargs, + ) + + def ukey(self, url): + """Unique identifier; assume HTTP files are static, unchanging""" + return tokenize(url, self.kwargs, self.protocol) + + async def _info(self, url, **kwargs): + """Get info of URL + + Tries to access location via HEAD, and then GET methods, but does + not fetch the data. + + It is possible that the server does not supply any size information, in + which case size will be given as None (and certain operations on the + corresponding file will not work). + """ + info = {} + session = await self.set_session() + + for policy in ["head", "get"]: + try: + info.update( + await _file_info( + self.encode_url(url), + size_policy=policy, + session=session, + **self.kwargs, + **kwargs, + ) + ) + if info.get("size") is not None: + break + except Exception as exc: + if policy == "get": + # If get failed, then raise a FileNotFoundError + raise FileNotFoundError(url) from exc + logger.debug("", exc_info=exc) + + return {"name": url, "size": None, **info, "type": "file"} + + async def _glob(self, path, maxdepth=None, **kwargs): + """ + Find files by glob-matching. + + This implementation is idntical to the one in AbstractFileSystem, + but "?" is not considered as a character for globbing, because it is + so common in URLs, often identifying the "query" part. + """ + if maxdepth is not None and maxdepth < 1: + raise ValueError("maxdepth must be at least 1") + import re + + ends_with_slash = path.endswith("/") # _strip_protocol strips trailing slash + path = self._strip_protocol(path) + append_slash_to_dirname = ends_with_slash or path.endswith(("/**", "/*")) + idx_star = path.find("*") if path.find("*") >= 0 else len(path) + idx_brace = path.find("[") if path.find("[") >= 0 else len(path) + + min_idx = min(idx_star, idx_brace) + + detail = kwargs.pop("detail", False) + + if not has_magic(path): + if await self._exists(path, **kwargs): + if not detail: + return [path] + else: + return {path: await self._info(path, **kwargs)} + else: + if not detail: + return [] # glob of non-existent returns empty + else: + return {} + elif "/" in path[:min_idx]: + min_idx = path[:min_idx].rindex("/") + root = path[: min_idx + 1] + depth = path[min_idx + 1 :].count("/") + 1 + else: + root = "" + depth = path[min_idx + 1 :].count("/") + 1 + + if "**" in path: + if maxdepth is not None: + idx_double_stars = path.find("**") + depth_double_stars = path[idx_double_stars:].count("/") + 1 + depth = depth - depth_double_stars + maxdepth + else: + depth = None + + allpaths = await self._find( + root, maxdepth=depth, withdirs=True, detail=True, **kwargs + ) + + pattern = glob_translate(path + ("/" if ends_with_slash else "")) + pattern = re.compile(pattern) + + out = { + ( + p.rstrip("/") + if not append_slash_to_dirname + and info["type"] == "directory" + and p.endswith("/") + else p + ): info + for p, info in sorted(allpaths.items()) + if pattern.match(p.rstrip("/")) + } + + if detail: + return out + else: + return list(out) + + async def _isdir(self, path): + # override, since all URLs are (also) files + try: + return bool(await self._ls(path)) + except (FileNotFoundError, ValueError): + return False + + async def _pipe_file(self, path, value, mode="overwrite", **kwargs): + """ + Write bytes to a remote file over HTTP. + + Parameters + ---------- + path : str + Target URL where the data should be written + value : bytes + Data to be written + mode : str + How to write to the file - 'overwrite' or 'append' + **kwargs : dict + Additional parameters to pass to the HTTP request + """ + url = self._strip_protocol(path) + headers = kwargs.pop("headers", {}) + headers["Content-Length"] = str(len(value)) + + session = await self.set_session() + + async with session.put(url, data=value, headers=headers, **kwargs) as r: + r.raise_for_status() + + +class HTTPFile(AbstractBufferedFile): + """ + A file-like object pointing to a remote HTTP(S) resource + + Supports only reading, with read-ahead of a predetermined block-size. + + In the case that the server does not supply the filesize, only reading of + the complete file in one go is supported. + + Parameters + ---------- + url: str + Full URL of the remote resource, including the protocol + session: aiohttp.ClientSession or None + All calls will be made within this session, to avoid restarting + connections where the server allows this + block_size: int or None + The amount of read-ahead to do, in bytes. Default is 5MB, or the value + configured for the FileSystem creating this file + size: None or int + If given, this is the size of the file in bytes, and we don't attempt + to call the server to find the value. + kwargs: all other key-values are passed to requests calls. + """ + + def __init__( + self, + fs, + url, + session=None, + block_size=None, + mode="rb", + cache_type="bytes", + cache_options=None, + size=None, + loop=None, + asynchronous=False, + **kwargs, + ): + if mode != "rb": + raise NotImplementedError("File mode not supported") + self.asynchronous = asynchronous + self.loop = loop + self.url = url + self.session = session + self.details = {"name": url, "size": size, "type": "file"} + super().__init__( + fs=fs, + path=url, + mode=mode, + block_size=block_size, + cache_type=cache_type, + cache_options=cache_options, + **kwargs, + ) + + def read(self, length=-1): + """Read bytes from file + + Parameters + ---------- + length: int + Read up to this many bytes. If negative, read all content to end of + file. If the server has not supplied the filesize, attempting to + read only part of the data will raise a ValueError. + """ + if ( + (length < 0 and self.loc == 0) # explicit read all + # but not when the size is known and fits into a block anyways + and not (self.size is not None and self.size <= self.blocksize) + ): + self._fetch_all() + if self.size is None: + if length < 0: + self._fetch_all() + else: + length = min(self.size - self.loc, length) + return super().read(length) + + async def async_fetch_all(self): + """Read whole file in one shot, without caching + + This is only called when position is still at zero, + and read() is called without a byte-count. + """ + logger.debug(f"Fetch all for {self}") + if not isinstance(self.cache, AllBytes): + r = await self.session.get(self.fs.encode_url(self.url), **self.kwargs) + async with r: + r.raise_for_status() + out = await r.read() + self.cache = AllBytes( + size=len(out), fetcher=None, blocksize=None, data=out + ) + self.size = len(out) + + _fetch_all = sync_wrapper(async_fetch_all) + + def _parse_content_range(self, headers): + """Parse the Content-Range header""" + s = headers.get("Content-Range", "") + m = re.match(r"bytes (\d+-\d+|\*)/(\d+|\*)", s) + if not m: + return None, None, None + + if m[1] == "*": + start = end = None + else: + start, end = [int(x) for x in m[1].split("-")] + total = None if m[2] == "*" else int(m[2]) + return start, end, total + + async def async_fetch_range(self, start, end): + """Download a block of data + + The expectation is that the server returns only the requested bytes, + with HTTP code 206. If this is not the case, we first check the headers, + and then stream the output - if the data size is bigger than we + requested, an exception is raised. + """ + logger.debug(f"Fetch range for {self}: {start}-{end}") + kwargs = self.kwargs.copy() + headers = kwargs.pop("headers", {}).copy() + headers["Range"] = f"bytes={start}-{end - 1}" + logger.debug(f"{self.url} : {headers['Range']}") + r = await self.session.get( + self.fs.encode_url(self.url), headers=headers, **kwargs + ) + async with r: + if r.status == 416: + # range request outside file + return b"" + r.raise_for_status() + + # If the server has handled the range request, it should reply + # with status 206 (partial content). But we'll guess that a suitable + # Content-Range header or a Content-Length no more than the + # requested range also mean we have got the desired range. + response_is_range = ( + r.status == 206 + or self._parse_content_range(r.headers)[0] == start + or int(r.headers.get("Content-Length", end + 1)) <= end - start + ) + + if response_is_range: + # partial content, as expected + out = await r.read() + elif start > 0: + raise ValueError( + "The HTTP server doesn't appear to support range requests. " + "Only reading this file from the beginning is supported. " + "Open with block_size=0 for a streaming file interface." + ) + else: + # Response is not a range, but we want the start of the file, + # so we can read the required amount anyway. + cl = 0 + out = [] + while True: + chunk = await r.content.read(2**20) + # data size unknown, let's read until we have enough + if chunk: + out.append(chunk) + cl += len(chunk) + if cl > end - start: + break + else: + break + out = b"".join(out)[: end - start] + return out + + _fetch_range = sync_wrapper(async_fetch_range) + + +magic_check = re.compile("([*[])") + + +def has_magic(s): + match = magic_check.search(s) + return match is not None + + +class HTTPStreamFile(AbstractBufferedFile): + def __init__(self, fs, url, mode="rb", loop=None, session=None, **kwargs): + self.asynchronous = kwargs.pop("asynchronous", False) + self.url = url + self.loop = loop + self.session = session + if mode != "rb": + raise ValueError + self.details = {"name": url, "size": None} + super().__init__(fs=fs, path=url, mode=mode, cache_type="none", **kwargs) + + async def cor(): + r = await self.session.get(self.fs.encode_url(url), **kwargs).__aenter__() + self.fs._raise_not_found_for_status(r, url) + return r + + self.r = sync(self.loop, cor) + self.loop = fs.loop + + def seek(self, loc, whence=0): + if loc == 0 and whence == 1: + return + if loc == self.loc and whence == 0: + return + raise ValueError("Cannot seek streaming HTTP file") + + async def _read(self, num=-1): + out = await self.r.content.read(num) + self.loc += len(out) + return out + + read = sync_wrapper(_read) + + async def _close(self): + self.r.close() + + def close(self): + asyncio.run_coroutine_threadsafe(self._close(), self.loop) + super().close() + + +class AsyncStreamFile(AbstractAsyncStreamedFile): + def __init__( + self, fs, url, mode="rb", loop=None, session=None, size=None, **kwargs + ): + self.url = url + self.session = session + self.r = None + if mode != "rb": + raise ValueError + self.details = {"name": url, "size": None} + self.kwargs = kwargs + super().__init__(fs=fs, path=url, mode=mode, cache_type="none") + self.size = size + + async def read(self, num=-1): + if self.r is None: + r = await self.session.get( + self.fs.encode_url(self.url), **self.kwargs + ).__aenter__() + self.fs._raise_not_found_for_status(r, self.url) + self.r = r + out = await self.r.content.read(num) + self.loc += len(out) + return out + + async def close(self): + if self.r is not None: + self.r.close() + self.r = None + await super().close() + + +async def get_range(session, url, start, end, file=None, **kwargs): + # explicit get a range when we know it must be safe + kwargs = kwargs.copy() + headers = kwargs.pop("headers", {}).copy() + headers["Range"] = f"bytes={start}-{end - 1}" + r = await session.get(url, headers=headers, **kwargs) + r.raise_for_status() + async with r: + out = await r.read() + if file: + with open(file, "r+b") as f: # noqa: ASYNC230 + f.seek(start) + f.write(out) + else: + return out + + +async def _file_info(url, session, size_policy="head", **kwargs): + """Call HEAD on the server to get details about the file (size/checksum etc.) + + Default operation is to explicitly allow redirects and use encoding + 'identity' (no compression) to get the true size of the target. + """ + logger.debug("Retrieve file size for %s", url) + kwargs = kwargs.copy() + ar = kwargs.pop("allow_redirects", True) + head = kwargs.get("headers", {}).copy() + head["Accept-Encoding"] = "identity" + kwargs["headers"] = head + + info = {} + if size_policy == "head": + r = await session.head(url, allow_redirects=ar, **kwargs) + elif size_policy == "get": + r = await session.get(url, allow_redirects=ar, **kwargs) + else: + raise TypeError(f'size_policy must be "head" or "get", got {size_policy}') + async with r: + r.raise_for_status() + + if "Content-Length" in r.headers: + # Some servers may choose to ignore Accept-Encoding and return + # compressed content, in which case the returned size is unreliable. + if "Content-Encoding" not in r.headers or r.headers["Content-Encoding"] in [ + "identity", + "", + ]: + info["size"] = int(r.headers["Content-Length"]) + elif "Content-Range" in r.headers: + info["size"] = int(r.headers["Content-Range"].split("/")[1]) + + if "Content-Type" in r.headers: + info["mimetype"] = r.headers["Content-Type"].partition(";")[0] + + if r.headers.get("Accept-Ranges") == "none": + # Some servers may explicitly discourage partial content requests, but + # the lack of "Accept-Ranges" does not always indicate they would fail + info["partial"] = False + + info["url"] = str(r.url) + + for checksum_field in ["ETag", "Content-MD5", "Digest", "Last-Modified"]: + if r.headers.get(checksum_field): + info[checksum_field] = r.headers[checksum_field] + + return info + + +async def _file_size(url, session=None, *args, **kwargs): + if session is None: + session = await get_client() + info = await _file_info(url, session=session, *args, **kwargs) + return info.get("size") + + +file_size = sync_wrapper(_file_size) diff --git a/fsspec/implementations/http_sync.py b/fsspec/implementations/http_sync.py new file mode 100644 index 0000000000000000000000000000000000000000..08799f20acde91b19dcdb4fbb02313bc64f9257a --- /dev/null +++ b/fsspec/implementations/http_sync.py @@ -0,0 +1,931 @@ +"""This file is largely copied from http.py""" + +import io +import logging +import re +import urllib.error +import urllib.parse +from copy import copy +from json import dumps, loads +from urllib.parse import urlparse + +try: + import yarl +except (ImportError, ModuleNotFoundError, OSError): + yarl = False + +from fsspec.callbacks import _DEFAULT_CALLBACK +from fsspec.registry import register_implementation +from fsspec.spec import AbstractBufferedFile, AbstractFileSystem +from fsspec.utils import DEFAULT_BLOCK_SIZE, isfilelike, nullcontext, tokenize + +from ..caching import AllBytes + +# https://stackoverflow.com/a/15926317/3821154 +ex = re.compile(r"""<(a|A)\s+(?:[^>]*?\s+)?(href|HREF)=["'](?P[^"']+)""") +ex2 = re.compile(r"""(?Phttp[s]?://[-a-zA-Z0-9@:%_+.~#?&/=]+)""") +logger = logging.getLogger("fsspec.http") + + +class JsHttpException(urllib.error.HTTPError): ... + + +class StreamIO(io.BytesIO): + # fake class, so you can set attributes on it + # will eventually actually stream + ... + + +class ResponseProxy: + """Looks like a requests response""" + + def __init__(self, req, stream=False): + self.request = req + self.stream = stream + self._data = None + self._headers = None + + @property + def raw(self): + if self._data is None: + b = self.request.response.to_bytes() + if self.stream: + self._data = StreamIO(b) + else: + self._data = b + return self._data + + def close(self): + if hasattr(self, "_data"): + del self._data + + @property + def headers(self): + if self._headers is None: + self._headers = dict( + [ + _.split(": ") + for _ in self.request.getAllResponseHeaders().strip().split("\r\n") + ] + ) + return self._headers + + @property + def status_code(self): + return int(self.request.status) + + def raise_for_status(self): + if not self.ok: + raise JsHttpException( + self.url, self.status_code, self.reason, self.headers, None + ) + + def iter_content(self, chunksize, *_, **__): + while True: + out = self.raw.read(chunksize) + if out: + yield out + else: + break + + @property + def reason(self): + return self.request.statusText + + @property + def ok(self): + return self.status_code < 400 + + @property + def url(self): + return self.request.response.responseURL + + @property + def text(self): + # TODO: encoding from headers + return self.content.decode() + + @property + def content(self): + self.stream = False + return self.raw + + def json(self): + return loads(self.text) + + +class RequestsSessionShim: + def __init__(self): + self.headers = {} + + def request( + self, + method, + url, + params=None, + data=None, + headers=None, + cookies=None, + files=None, + auth=None, + timeout=None, + allow_redirects=None, + proxies=None, + hooks=None, + stream=None, + verify=None, + cert=None, + json=None, + ): + from js import Blob, XMLHttpRequest + + logger.debug("JS request: %s %s", method, url) + + if cert or verify or proxies or files or cookies or hooks: + raise NotImplementedError + if data and json: + raise ValueError("Use json= or data=, not both") + req = XMLHttpRequest.new() + extra = auth if auth else () + if params: + url = f"{url}?{urllib.parse.urlencode(params)}" + req.open(method, url, False, *extra) + if timeout: + req.timeout = timeout + if headers: + for k, v in headers.items(): + req.setRequestHeader(k, v) + + req.setRequestHeader("Accept", "application/octet-stream") + req.responseType = "arraybuffer" + if json: + blob = Blob.new([dumps(data)], {type: "application/json"}) + req.send(blob) + elif data: + if isinstance(data, io.IOBase): + data = data.read() + blob = Blob.new([data], {type: "application/octet-stream"}) + req.send(blob) + else: + req.send(None) + return ResponseProxy(req, stream=stream) + + def get(self, url, **kwargs): + return self.request("GET", url, **kwargs) + + def head(self, url, **kwargs): + return self.request("HEAD", url, **kwargs) + + def post(self, url, **kwargs): + return self.request("POST}", url, **kwargs) + + def put(self, url, **kwargs): + return self.request("PUT", url, **kwargs) + + def patch(self, url, **kwargs): + return self.request("PATCH", url, **kwargs) + + def delete(self, url, **kwargs): + return self.request("DELETE", url, **kwargs) + + +class HTTPFileSystem(AbstractFileSystem): + """ + Simple File-System for fetching data via HTTP(S) + + This is the BLOCKING version of the normal HTTPFileSystem. It uses + requests in normal python and the JS runtime in pyodide. + + ***This implementation is extremely experimental, do not use unless + you are testing pyodide/pyscript integration*** + """ + + protocol = ("http", "https", "sync-http", "sync-https") + sep = "/" + + def __init__( + self, + simple_links=True, + block_size=None, + same_scheme=True, + cache_type="readahead", + cache_options=None, + client_kwargs=None, + encoded=False, + **storage_options, + ): + """ + + Parameters + ---------- + block_size: int + Blocks to read bytes; if 0, will default to raw requests file-like + objects instead of HTTPFile instances + simple_links: bool + If True, will consider both HTML tags and anything that looks + like a URL; if False, will consider only the former. + same_scheme: True + When doing ls/glob, if this is True, only consider paths that have + http/https matching the input URLs. + size_policy: this argument is deprecated + client_kwargs: dict + Passed to aiohttp.ClientSession, see + https://docs.aiohttp.org/en/stable/client_reference.html + For example, ``{'auth': aiohttp.BasicAuth('user', 'pass')}`` + storage_options: key-value + Any other parameters passed on to requests + cache_type, cache_options: defaults used in open + """ + super().__init__(self, **storage_options) + self.block_size = block_size if block_size is not None else DEFAULT_BLOCK_SIZE + self.simple_links = simple_links + self.same_schema = same_scheme + self.cache_type = cache_type + self.cache_options = cache_options + self.client_kwargs = client_kwargs or {} + self.encoded = encoded + self.kwargs = storage_options + + try: + import js # noqa: F401 + + logger.debug("Starting JS session") + self.session = RequestsSessionShim() + self.js = True + except Exception as e: + import requests + + logger.debug("Starting cpython session because of: %s", e) + self.session = requests.Session(**(client_kwargs or {})) + self.js = False + + request_options = copy(storage_options) + self.use_listings_cache = request_options.pop("use_listings_cache", False) + request_options.pop("listings_expiry_time", None) + request_options.pop("max_paths", None) + request_options.pop("skip_instance_cache", None) + self.kwargs = request_options + + @property + def fsid(self): + return "sync-http" + + def encode_url(self, url): + if yarl: + return yarl.URL(url, encoded=self.encoded) + return url + + @classmethod + def _strip_protocol(cls, path: str) -> str: + """For HTTP, we always want to keep the full URL""" + path = path.replace("sync-http://", "http://").replace( + "sync-https://", "https://" + ) + return path + + @classmethod + def _parent(cls, path): + # override, since _strip_protocol is different for URLs + par = super()._parent(path) + if len(par) > 7: # "http://..." + return par + return "" + + def _ls_real(self, url, detail=True, **kwargs): + # ignoring URL-encoded arguments + kw = self.kwargs.copy() + kw.update(kwargs) + logger.debug(url) + r = self.session.get(self.encode_url(url), **self.kwargs) + self._raise_not_found_for_status(r, url) + text = r.text + if self.simple_links: + links = ex2.findall(text) + [u[2] for u in ex.findall(text)] + else: + links = [u[2] for u in ex.findall(text)] + out = set() + parts = urlparse(url) + for l in links: + if isinstance(l, tuple): + l = l[1] + if l.startswith("/") and len(l) > 1: + # absolute URL on this server + l = parts.scheme + "://" + parts.netloc + l + if l.startswith("http"): + if self.same_schema and l.startswith(url.rstrip("/") + "/"): + out.add(l) + elif l.replace("https", "http").startswith( + url.replace("https", "http").rstrip("/") + "/" + ): + # allowed to cross http <-> https + out.add(l) + else: + if l not in ["..", "../"]: + # Ignore FTP-like "parent" + out.add("/".join([url.rstrip("/"), l.lstrip("/")])) + if not out and url.endswith("/"): + out = self._ls_real(url.rstrip("/"), detail=False) + if detail: + return [ + { + "name": u, + "size": None, + "type": "directory" if u.endswith("/") else "file", + } + for u in out + ] + else: + return sorted(out) + + def ls(self, url, detail=True, **kwargs): + if self.use_listings_cache and url in self.dircache: + out = self.dircache[url] + else: + out = self._ls_real(url, detail=detail, **kwargs) + self.dircache[url] = out + return out + + def _raise_not_found_for_status(self, response, url): + """ + Raises FileNotFoundError for 404s, otherwise uses raise_for_status. + """ + if response.status_code == 404: + raise FileNotFoundError(url) + response.raise_for_status() + + def cat_file(self, url, start=None, end=None, **kwargs): + kw = self.kwargs.copy() + kw.update(kwargs) + logger.debug(url) + + if start is not None or end is not None: + if start == end: + return b"" + headers = kw.pop("headers", {}).copy() + + headers["Range"] = self._process_limits(url, start, end) + kw["headers"] = headers + r = self.session.get(self.encode_url(url), **kw) + self._raise_not_found_for_status(r, url) + return r.content + + def get_file( + self, rpath, lpath, chunk_size=5 * 2**20, callback=_DEFAULT_CALLBACK, **kwargs + ): + kw = self.kwargs.copy() + kw.update(kwargs) + logger.debug(rpath) + r = self.session.get(self.encode_url(rpath), **kw) + try: + size = int( + r.headers.get("content-length", None) + or r.headers.get("Content-Length", None) + ) + except (ValueError, KeyError, TypeError): + size = None + + callback.set_size(size) + self._raise_not_found_for_status(r, rpath) + if not isfilelike(lpath): + lpath = open(lpath, "wb") + for chunk in r.iter_content(chunk_size, decode_unicode=False): + lpath.write(chunk) + callback.relative_update(len(chunk)) + + def put_file( + self, + lpath, + rpath, + chunk_size=5 * 2**20, + callback=_DEFAULT_CALLBACK, + method="post", + **kwargs, + ): + def gen_chunks(): + # Support passing arbitrary file-like objects + # and use them instead of streams. + if isinstance(lpath, io.IOBase): + context = nullcontext(lpath) + use_seek = False # might not support seeking + else: + context = open(lpath, "rb") + use_seek = True + + with context as f: + if use_seek: + callback.set_size(f.seek(0, 2)) + f.seek(0) + else: + callback.set_size(getattr(f, "size", None)) + + chunk = f.read(chunk_size) + while chunk: + yield chunk + callback.relative_update(len(chunk)) + chunk = f.read(chunk_size) + + kw = self.kwargs.copy() + kw.update(kwargs) + + method = method.lower() + if method not in ("post", "put"): + raise ValueError( + f"method has to be either 'post' or 'put', not: {method!r}" + ) + + meth = getattr(self.session, method) + resp = meth(rpath, data=gen_chunks(), **kw) + self._raise_not_found_for_status(resp, rpath) + + def _process_limits(self, url, start, end): + """Helper for "Range"-based _cat_file""" + size = None + suff = False + if start is not None and start < 0: + # if start is negative and end None, end is the "suffix length" + if end is None: + end = -start + start = "" + suff = True + else: + size = size or self.info(url)["size"] + start = size + start + elif start is None: + start = 0 + if not suff: + if end is not None and end < 0: + if start is not None: + size = size or self.info(url)["size"] + end = size + end + elif end is None: + end = "" + if isinstance(end, int): + end -= 1 # bytes range is inclusive + return f"bytes={start}-{end}" + + def exists(self, path, **kwargs): + kw = self.kwargs.copy() + kw.update(kwargs) + try: + logger.debug(path) + r = self.session.get(self.encode_url(path), **kw) + return r.status_code < 400 + except Exception: + return False + + def isfile(self, path, **kwargs): + return self.exists(path, **kwargs) + + def _open( + self, + path, + mode="rb", + block_size=None, + autocommit=None, # XXX: This differs from the base class. + cache_type=None, + cache_options=None, + size=None, + **kwargs, + ): + """Make a file-like object + + Parameters + ---------- + path: str + Full URL with protocol + mode: string + must be "rb" + block_size: int or None + Bytes to download in one request; use instance value if None. If + zero, will return a streaming Requests file-like instance. + kwargs: key-value + Any other parameters, passed to requests calls + """ + if mode != "rb": + raise NotImplementedError + block_size = block_size if block_size is not None else self.block_size + kw = self.kwargs.copy() + kw.update(kwargs) + size = size or self.info(path, **kwargs)["size"] + if block_size and size: + return HTTPFile( + self, + path, + session=self.session, + block_size=block_size, + mode=mode, + size=size, + cache_type=cache_type or self.cache_type, + cache_options=cache_options or self.cache_options, + **kw, + ) + else: + return HTTPStreamFile( + self, + path, + mode=mode, + session=self.session, + **kw, + ) + + def ukey(self, url): + """Unique identifier; assume HTTP files are static, unchanging""" + return tokenize(url, self.kwargs, self.protocol) + + def info(self, url, **kwargs): + """Get info of URL + + Tries to access location via HEAD, and then GET methods, but does + not fetch the data. + + It is possible that the server does not supply any size information, in + which case size will be given as None (and certain operations on the + corresponding file will not work). + """ + info = {} + for policy in ["head", "get"]: + try: + info.update( + _file_info( + self.encode_url(url), + size_policy=policy, + session=self.session, + **self.kwargs, + **kwargs, + ) + ) + if info.get("size") is not None: + break + except Exception as exc: + if policy == "get": + # If get failed, then raise a FileNotFoundError + raise FileNotFoundError(url) from exc + logger.debug(str(exc)) + + return {"name": url, "size": None, **info, "type": "file"} + + def glob(self, path, maxdepth=None, **kwargs): + """ + Find files by glob-matching. + + This implementation is idntical to the one in AbstractFileSystem, + but "?" is not considered as a character for globbing, because it is + so common in URLs, often identifying the "query" part. + """ + import re + + ends = path.endswith("/") + path = self._strip_protocol(path) + indstar = path.find("*") if path.find("*") >= 0 else len(path) + indbrace = path.find("[") if path.find("[") >= 0 else len(path) + + ind = min(indstar, indbrace) + + detail = kwargs.pop("detail", False) + + if not has_magic(path): + root = path + depth = 1 + if ends: + path += "/*" + elif self.exists(path): + if not detail: + return [path] + else: + return {path: self.info(path)} + else: + if not detail: + return [] # glob of non-existent returns empty + else: + return {} + elif "/" in path[:ind]: + ind2 = path[:ind].rindex("/") + root = path[: ind2 + 1] + depth = None if "**" in path else path[ind2 + 1 :].count("/") + 1 + else: + root = "" + depth = None if "**" in path else path[ind + 1 :].count("/") + 1 + + allpaths = self.find( + root, maxdepth=maxdepth or depth, withdirs=True, detail=True, **kwargs + ) + # Escape characters special to python regex, leaving our supported + # special characters in place. + # See https://www.gnu.org/software/bash/manual/html_node/Pattern-Matching.html + # for shell globbing details. + pattern = ( + "^" + + ( + path.replace("\\", r"\\") + .replace(".", r"\.") + .replace("+", r"\+") + .replace("//", "/") + .replace("(", r"\(") + .replace(")", r"\)") + .replace("|", r"\|") + .replace("^", r"\^") + .replace("$", r"\$") + .replace("{", r"\{") + .replace("}", r"\}") + .rstrip("/") + ) + + "$" + ) + pattern = re.sub("[*]{2}", "=PLACEHOLDER=", pattern) + pattern = re.sub("[*]", "[^/]*", pattern) + pattern = re.compile(pattern.replace("=PLACEHOLDER=", ".*")) + out = { + p: allpaths[p] + for p in sorted(allpaths) + if pattern.match(p.replace("//", "/").rstrip("/")) + } + if detail: + return out + else: + return list(out) + + def isdir(self, path): + # override, since all URLs are (also) files + try: + return bool(self.ls(path)) + except (FileNotFoundError, ValueError): + return False + + +class HTTPFile(AbstractBufferedFile): + """ + A file-like object pointing to a remove HTTP(S) resource + + Supports only reading, with read-ahead of a predermined block-size. + + In the case that the server does not supply the filesize, only reading of + the complete file in one go is supported. + + Parameters + ---------- + url: str + Full URL of the remote resource, including the protocol + session: requests.Session or None + All calls will be made within this session, to avoid restarting + connections where the server allows this + block_size: int or None + The amount of read-ahead to do, in bytes. Default is 5MB, or the value + configured for the FileSystem creating this file + size: None or int + If given, this is the size of the file in bytes, and we don't attempt + to call the server to find the value. + kwargs: all other key-values are passed to requests calls. + """ + + def __init__( + self, + fs, + url, + session=None, + block_size=None, + mode="rb", + cache_type="bytes", + cache_options=None, + size=None, + **kwargs, + ): + if mode != "rb": + raise NotImplementedError("File mode not supported") + self.url = url + self.session = session + self.details = {"name": url, "size": size, "type": "file"} + super().__init__( + fs=fs, + path=url, + mode=mode, + block_size=block_size, + cache_type=cache_type, + cache_options=cache_options, + **kwargs, + ) + + def read(self, length=-1): + """Read bytes from file + + Parameters + ---------- + length: int + Read up to this many bytes. If negative, read all content to end of + file. If the server has not supplied the filesize, attempting to + read only part of the data will raise a ValueError. + """ + if ( + (length < 0 and self.loc == 0) # explicit read all + # but not when the size is known and fits into a block anyways + and not (self.size is not None and self.size <= self.blocksize) + ): + self._fetch_all() + if self.size is None: + if length < 0: + self._fetch_all() + else: + length = min(self.size - self.loc, length) + return super().read(length) + + def _fetch_all(self): + """Read whole file in one shot, without caching + + This is only called when position is still at zero, + and read() is called without a byte-count. + """ + logger.debug(f"Fetch all for {self}") + if not isinstance(self.cache, AllBytes): + r = self.session.get(self.fs.encode_url(self.url), **self.kwargs) + r.raise_for_status() + out = r.content + self.cache = AllBytes(size=len(out), fetcher=None, blocksize=None, data=out) + self.size = len(out) + + def _parse_content_range(self, headers): + """Parse the Content-Range header""" + s = headers.get("Content-Range", "") + m = re.match(r"bytes (\d+-\d+|\*)/(\d+|\*)", s) + if not m: + return None, None, None + + if m[1] == "*": + start = end = None + else: + start, end = [int(x) for x in m[1].split("-")] + total = None if m[2] == "*" else int(m[2]) + return start, end, total + + def _fetch_range(self, start, end): + """Download a block of data + + The expectation is that the server returns only the requested bytes, + with HTTP code 206. If this is not the case, we first check the headers, + and then stream the output - if the data size is bigger than we + requested, an exception is raised. + """ + logger.debug(f"Fetch range for {self}: {start}-{end}") + kwargs = self.kwargs.copy() + headers = kwargs.pop("headers", {}).copy() + headers["Range"] = f"bytes={start}-{end - 1}" + logger.debug("%s : %s", self.url, headers["Range"]) + r = self.session.get(self.fs.encode_url(self.url), headers=headers, **kwargs) + if r.status_code == 416: + # range request outside file + return b"" + r.raise_for_status() + + # If the server has handled the range request, it should reply + # with status 206 (partial content). But we'll guess that a suitable + # Content-Range header or a Content-Length no more than the + # requested range also mean we have got the desired range. + cl = r.headers.get("Content-Length", r.headers.get("content-length", end + 1)) + response_is_range = ( + r.status_code == 206 + or self._parse_content_range(r.headers)[0] == start + or int(cl) <= end - start + ) + + if response_is_range: + # partial content, as expected + out = r.content + elif start > 0: + raise ValueError( + "The HTTP server doesn't appear to support range requests. " + "Only reading this file from the beginning is supported. " + "Open with block_size=0 for a streaming file interface." + ) + else: + # Response is not a range, but we want the start of the file, + # so we can read the required amount anyway. + cl = 0 + out = [] + for chunk in r.iter_content(2**20, False): + out.append(chunk) + cl += len(chunk) + out = b"".join(out)[: end - start] + return out + + +magic_check = re.compile("([*[])") + + +def has_magic(s): + match = magic_check.search(s) + return match is not None + + +class HTTPStreamFile(AbstractBufferedFile): + def __init__(self, fs, url, mode="rb", session=None, **kwargs): + self.url = url + self.session = session + if mode != "rb": + raise ValueError + self.details = {"name": url, "size": None} + super().__init__(fs=fs, path=url, mode=mode, cache_type="readahead", **kwargs) + + r = self.session.get(self.fs.encode_url(url), stream=True, **kwargs) + self.fs._raise_not_found_for_status(r, url) + self.it = r.iter_content(1024, False) + self.leftover = b"" + + self.r = r + + def seek(self, *args, **kwargs): + raise ValueError("Cannot seek streaming HTTP file") + + def read(self, num=-1): + bufs = [self.leftover] + leng = len(self.leftover) + while leng < num or num < 0: + try: + out = self.it.__next__() + except StopIteration: + break + if out: + bufs.append(out) + else: + break + leng += len(out) + out = b"".join(bufs) + if num >= 0: + self.leftover = out[num:] + out = out[:num] + else: + self.leftover = b"" + self.loc += len(out) + return out + + def close(self): + self.r.close() + self.closed = True + + +def get_range(session, url, start, end, **kwargs): + # explicit get a range when we know it must be safe + kwargs = kwargs.copy() + headers = kwargs.pop("headers", {}).copy() + headers["Range"] = f"bytes={start}-{end - 1}" + r = session.get(url, headers=headers, **kwargs) + r.raise_for_status() + return r.content + + +def _file_info(url, session, size_policy="head", **kwargs): + """Call HEAD on the server to get details about the file (size/checksum etc.) + + Default operation is to explicitly allow redirects and use encoding + 'identity' (no compression) to get the true size of the target. + """ + logger.debug("Retrieve file size for %s", url) + kwargs = kwargs.copy() + ar = kwargs.pop("allow_redirects", True) + head = kwargs.get("headers", {}).copy() + # TODO: not allowed in JS + # head["Accept-Encoding"] = "identity" + kwargs["headers"] = head + + info = {} + if size_policy == "head": + r = session.head(url, allow_redirects=ar, **kwargs) + elif size_policy == "get": + r = session.get(url, allow_redirects=ar, **kwargs) + else: + raise TypeError(f'size_policy must be "head" or "get", got {size_policy}') + r.raise_for_status() + + # TODO: + # recognise lack of 'Accept-Ranges', + # or 'Accept-Ranges': 'none' (not 'bytes') + # to mean streaming only, no random access => return None + if "Content-Length" in r.headers: + info["size"] = int(r.headers["Content-Length"]) + elif "Content-Range" in r.headers: + info["size"] = int(r.headers["Content-Range"].split("/")[1]) + elif "content-length" in r.headers: + info["size"] = int(r.headers["content-length"]) + elif "content-range" in r.headers: + info["size"] = int(r.headers["content-range"].split("/")[1]) + + for checksum_field in ["ETag", "Content-MD5", "Digest"]: + if r.headers.get(checksum_field): + info[checksum_field] = r.headers[checksum_field] + + return info + + +# importing this is enough to register it +def register(): + register_implementation("http", HTTPFileSystem, clobber=True) + register_implementation("https", HTTPFileSystem, clobber=True) + register_implementation("sync-http", HTTPFileSystem, clobber=True) + register_implementation("sync-https", HTTPFileSystem, clobber=True) + + +register() + + +def unregister(): + from fsspec.implementations.http import HTTPFileSystem + + register_implementation("http", HTTPFileSystem, clobber=True) + register_implementation("https", HTTPFileSystem, clobber=True) diff --git a/fsspec/implementations/jupyter.py b/fsspec/implementations/jupyter.py new file mode 100644 index 0000000000000000000000000000000000000000..2839f4c1feea56dddd54bdc00f0b884c8461d29e --- /dev/null +++ b/fsspec/implementations/jupyter.py @@ -0,0 +1,124 @@ +import base64 +import io +import re + +import requests + +import fsspec + + +class JupyterFileSystem(fsspec.AbstractFileSystem): + """View of the files as seen by a Jupyter server (notebook or lab)""" + + protocol = ("jupyter", "jlab") + + def __init__(self, url, tok=None, **kwargs): + """ + + Parameters + ---------- + url : str + Base URL of the server, like "http://127.0.0.1:8888". May include + token in the string, which is given by the process when starting up + tok : str + If the token is obtained separately, can be given here + kwargs + """ + if "?" in url: + if tok is None: + try: + tok = re.findall("token=([a-z0-9]+)", url)[0] + except IndexError as e: + raise ValueError("Could not determine token") from e + url = url.split("?", 1)[0] + self.url = url.rstrip("/") + "/api/contents" + self.session = requests.Session() + if tok: + self.session.headers["Authorization"] = f"token {tok}" + + super().__init__(**kwargs) + + def ls(self, path, detail=True, **kwargs): + path = self._strip_protocol(path) + r = self.session.get(f"{self.url}/{path}") + if r.status_code == 404: + return FileNotFoundError(path) + r.raise_for_status() + out = r.json() + + if out["type"] == "directory": + out = out["content"] + else: + out = [out] + for o in out: + o["name"] = o.pop("path") + o.pop("content") + if o["type"] == "notebook": + o["type"] = "file" + if detail: + return out + return [o["name"] for o in out] + + def cat_file(self, path, start=None, end=None, **kwargs): + path = self._strip_protocol(path) + r = self.session.get(f"{self.url}/{path}") + if r.status_code == 404: + return FileNotFoundError(path) + r.raise_for_status() + out = r.json() + if out["format"] == "text": + # data should be binary + b = out["content"].encode() + else: + b = base64.b64decode(out["content"]) + return b[start:end] + + def pipe_file(self, path, value, **_): + path = self._strip_protocol(path) + json = { + "name": path.rsplit("/", 1)[-1], + "path": path, + "size": len(value), + "content": base64.b64encode(value).decode(), + "format": "base64", + "type": "file", + } + self.session.put(f"{self.url}/{path}", json=json) + + def mkdir(self, path, create_parents=True, **kwargs): + path = self._strip_protocol(path) + if create_parents and "/" in path: + self.mkdir(path.rsplit("/", 1)[0], True) + json = { + "name": path.rsplit("/", 1)[-1], + "path": path, + "size": None, + "content": None, + "type": "directory", + } + self.session.put(f"{self.url}/{path}", json=json) + + def _rm(self, path): + path = self._strip_protocol(path) + self.session.delete(f"{self.url}/{path}") + + def _open(self, path, mode="rb", **kwargs): + path = self._strip_protocol(path) + if mode == "rb": + data = self.cat_file(path) + return io.BytesIO(data) + else: + return SimpleFileWriter(self, path, mode="wb") + + +class SimpleFileWriter(fsspec.spec.AbstractBufferedFile): + def _upload_chunk(self, final=False): + """Never uploads a chunk until file is done + + Not suitable for large files + """ + if final is False: + return False + self.buffer.seek(0) + data = self.buffer.read() + self.fs.pipe_file(self.path, data) diff --git a/fsspec/implementations/libarchive.py b/fsspec/implementations/libarchive.py new file mode 100644 index 0000000000000000000000000000000000000000..eb6f145352e1989e0477e259be02d8d7f4d729e2 --- /dev/null +++ b/fsspec/implementations/libarchive.py @@ -0,0 +1,213 @@ +from contextlib import contextmanager +from ctypes import ( + CFUNCTYPE, + POINTER, + c_int, + c_longlong, + c_void_p, + cast, + create_string_buffer, +) + +import libarchive +import libarchive.ffi as ffi + +from fsspec import open_files +from fsspec.archive import AbstractArchiveFileSystem +from fsspec.implementations.memory import MemoryFile +from fsspec.utils import DEFAULT_BLOCK_SIZE + +# Libarchive requires seekable files or memory only for certain archive +# types. However, since we read the directory first to cache the contents +# and also allow random access to any file, the file-like object needs +# to be seekable no matter what. + +# Seek call-backs (not provided in the libarchive python wrapper) +SEEK_CALLBACK = CFUNCTYPE(c_longlong, c_int, c_void_p, c_longlong, c_int) +read_set_seek_callback = ffi.ffi( + "read_set_seek_callback", [ffi.c_archive_p, SEEK_CALLBACK], c_int, ffi.check_int +) +new_api = hasattr(ffi, "NO_OPEN_CB") + + +@contextmanager +def custom_reader(file, format_name="all", filter_name="all", block_size=ffi.page_size): + """Read an archive from a seekable file-like object. + + The `file` object must support the standard `readinto` and 'seek' methods. + """ + buf = create_string_buffer(block_size) + buf_p = cast(buf, c_void_p) + + def read_func(archive_p, context, ptrptr): + # readinto the buffer, returns number of bytes read + length = file.readinto(buf) + # write the address of the buffer into the pointer + ptrptr = cast(ptrptr, POINTER(c_void_p)) + ptrptr[0] = buf_p + # tell libarchive how much data was written into the buffer + return length + + def seek_func(archive_p, context, offset, whence): + file.seek(offset, whence) + # tell libarchvie the current position + return file.tell() + + read_cb = ffi.READ_CALLBACK(read_func) + seek_cb = SEEK_CALLBACK(seek_func) + + if new_api: + open_cb = ffi.NO_OPEN_CB + close_cb = ffi.NO_CLOSE_CB + else: + open_cb = libarchive.read.OPEN_CALLBACK(ffi.VOID_CB) + close_cb = libarchive.read.CLOSE_CALLBACK(ffi.VOID_CB) + + with libarchive.read.new_archive_read(format_name, filter_name) as archive_p: + read_set_seek_callback(archive_p, seek_cb) + ffi.read_open(archive_p, None, open_cb, read_cb, close_cb) + yield libarchive.read.ArchiveRead(archive_p) + + +class LibArchiveFileSystem(AbstractArchiveFileSystem): + """Compressed archives as a file-system (read-only) + + Supports the following formats: + tar, pax , cpio, ISO9660, zip, mtree, shar, ar, raw, xar, lha/lzh, rar + Microsoft CAB, 7-Zip, WARC + + See the libarchive documentation for further restrictions. + https://www.libarchive.org/ + + Keeps file object open while instance lives. It only works in seekable + file-like objects. In case the filesystem does not support this kind of + file object, it is recommended to cache locally. + + This class is pickleable, but not necessarily thread-safe (depends on the + platform). See libarchive documentation for details. + """ + + root_marker = "" + protocol = "libarchive" + cachable = False + + def __init__( + self, + fo="", + mode="r", + target_protocol=None, + target_options=None, + block_size=DEFAULT_BLOCK_SIZE, + **kwargs, + ): + """ + Parameters + ---------- + fo: str or file-like + Contains ZIP, and must exist. If a str, will fetch file using + :meth:`~fsspec.open_files`, which must return one file exactly. + mode: str + Currently, only 'r' accepted + target_protocol: str (optional) + If ``fo`` is a string, this value can be used to override the + FS protocol inferred from a URL + target_options: dict (optional) + Kwargs passed when instantiating the target FS, if ``fo`` is + a string. + """ + super().__init__(self, **kwargs) + if mode != "r": + raise ValueError("Only read from archive files accepted") + if isinstance(fo, str): + files = open_files(fo, protocol=target_protocol, **(target_options or {})) + if len(files) != 1: + raise ValueError( + f'Path "{fo}" did not resolve to exactly one file: "{files}"' + ) + fo = files[0] + self.of = fo + self.fo = fo.__enter__() # the whole instance is a context + self.block_size = block_size + self.dir_cache = None + + @contextmanager + def _open_archive(self): + self.fo.seek(0) + with custom_reader(self.fo, block_size=self.block_size) as arc: + yield arc + + @classmethod + def _strip_protocol(cls, path): + # file paths are always relative to the archive root + return super()._strip_protocol(path).lstrip("/") + + def _get_dirs(self): + fields = { + "name": "pathname", + "size": "size", + "created": "ctime", + "mode": "mode", + "uid": "uid", + "gid": "gid", + "mtime": "mtime", + } + + if self.dir_cache is not None: + return + + self.dir_cache = {} + list_names = [] + with self._open_archive() as arc: + for entry in arc: + if not entry.isdir and not entry.isfile: + # Skip symbolic links, fifo entries, etc. + continue + self.dir_cache.update( + { + dirname: {"name": dirname, "size": 0, "type": "directory"} + for dirname in self._all_dirnames(set(entry.name)) + } + ) + f = {key: getattr(entry, fields[key]) for key in fields} + f["type"] = "directory" if entry.isdir else "file" + list_names.append(entry.name) + + self.dir_cache[f["name"]] = f + # libarchive does not seem to return an entry for the directories (at least + # not in all formats), so get the directories names from the files names + self.dir_cache.update( + { + dirname: {"name": dirname, "size": 0, "type": "directory"} + for dirname in self._all_dirnames(list_names) + } + ) + + def _open( + self, + path, + mode="rb", + block_size=None, + autocommit=True, + cache_options=None, + **kwargs, + ): + path = self._strip_protocol(path) + if mode != "rb": + raise NotImplementedError + + data = bytes() + with self._open_archive() as arc: + for entry in arc: + if entry.pathname != path: + continue + + if entry.size == 0: + # empty file, so there are no blocks + break + + for block in entry.get_blocks(entry.size): + data = block + break + else: + raise ValueError + return MemoryFile(fs=self, path=path, data=data) diff --git a/fsspec/implementations/local.py b/fsspec/implementations/local.py new file mode 100644 index 0000000000000000000000000000000000000000..64dc8bb21956e9c5024b47d46d9e9550931891d8 --- /dev/null +++ b/fsspec/implementations/local.py @@ -0,0 +1,514 @@ +import datetime +import io +import logging +import os +import os.path as osp +import shutil +import stat +import tempfile +from functools import lru_cache + +from fsspec import AbstractFileSystem +from fsspec.compression import compr +from fsspec.core import get_compression +from fsspec.utils import isfilelike, stringify_path + +logger = logging.getLogger("fsspec.local") + + +class LocalFileSystem(AbstractFileSystem): + """Interface to files on local storage + + Parameters + ---------- + auto_mkdir: bool + Whether, when opening a file, the directory containing it should + be created (if it doesn't already exist). This is assumed by pyarrow + code. + """ + + root_marker = "/" + protocol = "file", "local" + local_file = True + + def __init__(self, auto_mkdir=False, **kwargs): + super().__init__(**kwargs) + self.auto_mkdir = auto_mkdir + + @property + def fsid(self): + return "local" + + def mkdir(self, path, create_parents=True, **kwargs): + path = self._strip_protocol(path) + if self.exists(path): + raise FileExistsError(path) + if create_parents: + self.makedirs(path, exist_ok=True) + else: + os.mkdir(path, **kwargs) + + def makedirs(self, path, exist_ok=False): + path = self._strip_protocol(path) + os.makedirs(path, exist_ok=exist_ok) + + def rmdir(self, path): + path = self._strip_protocol(path) + os.rmdir(path) + + def ls(self, path, detail=False, **kwargs): + path = self._strip_protocol(path) + path_info = self.info(path) + infos = [] + if path_info["type"] == "directory": + with os.scandir(path) as it: + for f in it: + try: + # Only get the info if requested since it is a bit expensive (the stat call inside) + # The strip_protocol is also used in info() and calls make_path_posix to always return posix paths + info = self.info(f) if detail else self._strip_protocol(f.path) + infos.append(info) + except FileNotFoundError: + pass + else: + infos = [path_info] if detail else [path_info["name"]] + + return infos + + def info(self, path, **kwargs): + if isinstance(path, os.DirEntry): + # scandir DirEntry + out = path.stat(follow_symlinks=False) + link = path.is_symlink() + if path.is_dir(follow_symlinks=False): + t = "directory" + elif path.is_file(follow_symlinks=False): + t = "file" + else: + t = "other" + + size = out.st_size + if link: + try: + out2 = path.stat(follow_symlinks=True) + size = out2.st_size + except OSError: + size = 0 + path = self._strip_protocol(path.path) + else: + # str or path-like + path = self._strip_protocol(path) + out = os.stat(path, follow_symlinks=False) + link = stat.S_ISLNK(out.st_mode) + if link: + out = os.stat(path, follow_symlinks=True) + size = out.st_size + if stat.S_ISDIR(out.st_mode): + t = "directory" + elif stat.S_ISREG(out.st_mode): + t = "file" + else: + t = "other" + + # Check for the 'st_birthtime' attribute, which is not always present; fallback to st_ctime + created_time = getattr(out, "st_birthtime", out.st_ctime) + + result = { + "name": path, + "size": size, + "type": t, + "created": created_time, + "islink": link, + } + for field in ["mode", "uid", "gid", "mtime", "ino", "nlink"]: + result[field] = getattr(out, f"st_{field}") + if link: + result["destination"] = os.readlink(path) + return result + + def lexists(self, path, **kwargs): + return osp.lexists(path) + + def cp_file(self, path1, path2, **kwargs): + path1 = self._strip_protocol(path1) + path2 = self._strip_protocol(path2) + if self.auto_mkdir: + self.makedirs(self._parent(path2), exist_ok=True) + if self.isfile(path1): + shutil.copyfile(path1, path2) + elif self.isdir(path1): + self.mkdirs(path2, exist_ok=True) + else: + raise FileNotFoundError(path1) + + def isfile(self, path): + path = self._strip_protocol(path) + return os.path.isfile(path) + + def isdir(self, path): + path = self._strip_protocol(path) + return os.path.isdir(path) + + def get_file(self, path1, path2, callback=None, **kwargs): + if isfilelike(path2): + with open(path1, "rb") as f: + shutil.copyfileobj(f, path2) + else: + return self.cp_file(path1, path2, **kwargs) + + def put_file(self, path1, path2, callback=None, **kwargs): + return self.cp_file(path1, path2, **kwargs) + + def mv(self, path1, path2, recursive: bool = True, **kwargs): + """Move files/directories + For the specific case of local, all ops on directories are recursive and + the recursive= kwarg is ignored. + """ + path1 = self._strip_protocol(path1) + path2 = self._strip_protocol(path2) + shutil.move(path1, path2) + + def link(self, src, dst, **kwargs): + src = self._strip_protocol(src) + dst = self._strip_protocol(dst) + os.link(src, dst, **kwargs) + + def symlink(self, src, dst, **kwargs): + src = self._strip_protocol(src) + dst = self._strip_protocol(dst) + os.symlink(src, dst, **kwargs) + + def islink(self, path) -> bool: + return os.path.islink(self._strip_protocol(path)) + + def rm_file(self, path): + os.remove(self._strip_protocol(path)) + + def rm(self, path, recursive=False, maxdepth=None): + if not isinstance(path, list): + path = [path] + + for p in path: + p = self._strip_protocol(p) + if self.isdir(p): + if not recursive: + raise ValueError("Cannot delete directory, set recursive=True") + if osp.abspath(p) == os.getcwd(): + raise ValueError("Cannot delete current working directory") + shutil.rmtree(p) + else: + os.remove(p) + + def unstrip_protocol(self, name): + name = self._strip_protocol(name) # normalise for local/win/... + return f"file://{name}" + + def _open(self, path, mode="rb", block_size=None, **kwargs): + path = self._strip_protocol(path) + if self.auto_mkdir and "w" in mode: + self.makedirs(self._parent(path), exist_ok=True) + return LocalFileOpener(path, mode, fs=self, **kwargs) + + def touch(self, path, truncate=True, **kwargs): + path = self._strip_protocol(path) + if self.auto_mkdir: + self.makedirs(self._parent(path), exist_ok=True) + if self.exists(path): + os.utime(path, None) + else: + open(path, "a").close() + if truncate: + os.truncate(path, 0) + + def created(self, path): + info = self.info(path=path) + return datetime.datetime.fromtimestamp( + info["created"], tz=datetime.timezone.utc + ) + + def modified(self, path): + info = self.info(path=path) + return datetime.datetime.fromtimestamp(info["mtime"], tz=datetime.timezone.utc) + + @classmethod + def _parent(cls, path): + path = cls._strip_protocol(path) + if os.sep == "/": + # posix native + return path.rsplit("/", 1)[0] or "/" + else: + # NT + path_ = path.rsplit("/", 1)[0] + if len(path_) <= 3: + if path_[1:2] == ":": + # nt root (something like c:/) + return path_[0] + ":/" + # More cases may be required here + return path_ + + @classmethod + def _strip_protocol(cls, path): + path = stringify_path(path) + if path.startswith("file://"): + path = path[7:] + elif path.startswith("file:"): + path = path[5:] + elif path.startswith("local://"): + path = path[8:] + elif path.startswith("local:"): + path = path[6:] + + path = make_path_posix(path) + if os.sep != "/": + # This code-path is a stripped down version of + # > drive, path = ntpath.splitdrive(path) + if path[1:2] == ":": + # Absolute drive-letter path, e.g. X:\Windows + # Relative path with drive, e.g. X:Windows + drive, path = path[:2], path[2:] + elif path[:2] == "//": + # UNC drives, e.g. \\server\share or \\?\UNC\server\share + # Device drives, e.g. \\.\device or \\?\device + if (index1 := path.find("/", 2)) == -1 or ( + index2 := path.find("/", index1 + 1) + ) == -1: + drive, path = path, "" + else: + drive, path = path[:index2], path[index2:] + else: + # Relative path, e.g. Windows + drive = "" + + path = path.rstrip("/") or cls.root_marker + return drive + path + + else: + return path.rstrip("/") or cls.root_marker + + def _isfilestore(self): + # Inheriting from DaskFileSystem makes this False (S3, etc. were) + # the original motivation. But we are a posix-like file system. + # See https://github.com/dask/dask/issues/5526 + return True + + def chmod(self, path, mode): + path = stringify_path(path) + return os.chmod(path, mode) + + +def make_path_posix(path): + """Make path generic and absolute for current OS""" + if not isinstance(path, str): + if isinstance(path, (list, set, tuple)): + return type(path)(make_path_posix(p) for p in path) + else: + path = stringify_path(path) + if not isinstance(path, str): + raise TypeError(f"could not convert {path!r} to string") + if os.sep == "/": + # Native posix + if path.startswith("/"): + # most common fast case for posix + return path + elif path.startswith("~"): + return osp.expanduser(path) + elif path.startswith("./"): + path = path[2:] + elif path == ".": + path = "" + return f"{os.getcwd()}/{path}" + else: + # NT handling + if path[0:1] == "/" and path[2:3] == ":": + # path is like "/c:/local/path" + path = path[1:] + if path[1:2] == ":": + # windows full path like "C:\\local\\path" + if len(path) <= 3: + # nt root (something like c:/) + return path[0] + ":/" + path = path.replace("\\", "/") + return path + elif path[0:1] == "~": + return make_path_posix(osp.expanduser(path)) + elif path.startswith(("\\\\", "//")): + # windows UNC/DFS-style paths + return "//" + path[2:].replace("\\", "/") + elif path.startswith(("\\", "/")): + # windows relative path with root + path = path.replace("\\", "/") + return f"{osp.splitdrive(os.getcwd())[0]}{path}" + else: + path = path.replace("\\", "/") + if path.startswith("./"): + path = path[2:] + elif path == ".": + path = "" + return f"{make_path_posix(os.getcwd())}/{path}" + + +def trailing_sep(path): + """Return True if the path ends with a path separator. + + A forward slash is always considered a path separator, even on Operating + Systems that normally use a backslash. + """ + # TODO: if all incoming paths were posix-compliant then separator would + # always be a forward slash, simplifying this function. + # See https://github.com/fsspec/filesystem_spec/pull/1250 + return path.endswith(os.sep) or (os.altsep is not None and path.endswith(os.altsep)) + + +@lru_cache(maxsize=1) +def get_umask(mask: int = 0o666) -> int: + """Get the current umask. + + Follows https://stackoverflow.com/a/44130549 to get the umask. + Temporarily sets the umask to the given value, and then resets it to the + original value. + """ + value = os.umask(mask) + os.umask(value) + return value + + +class LocalFileOpener(io.IOBase): + def __init__( + self, path, mode, autocommit=True, fs=None, compression=None, **kwargs + ): + logger.debug("open file: %s", path) + self.path = path + self.mode = mode + self.fs = fs + self.f = None + self.autocommit = autocommit + self.compression = get_compression(path, compression) + self.blocksize = io.DEFAULT_BUFFER_SIZE + self._open() + + def _open(self): + if self.f is None or self.f.closed: + if self.autocommit or "w" not in self.mode: + self.f = open(self.path, mode=self.mode) + if self.compression: + compress = compr[self.compression] + self.f = compress(self.f, mode=self.mode) + else: + # TODO: check if path is writable? + i, name = tempfile.mkstemp() + os.close(i) # we want normal open and normal buffered file + self.temp = name + self.f = open(name, mode=self.mode) + if "w" not in self.mode: + self.size = self.f.seek(0, 2) + self.f.seek(0) + self.f.size = self.size + + def _fetch_range(self, start, end): + # probably only used by cached FS + if "r" not in self.mode: + raise ValueError + self._open() + self.f.seek(start) + return self.f.read(end - start) + + def __setstate__(self, state): + self.f = None + loc = state.pop("loc", None) + self.__dict__.update(state) + if "r" in state["mode"]: + self.f = None + self._open() + self.f.seek(loc) + + def __getstate__(self): + d = self.__dict__.copy() + d.pop("f") + if "r" in self.mode: + d["loc"] = self.f.tell() + else: + if not self.f.closed: + raise ValueError("Cannot serialise open write-mode local file") + return d + + def commit(self): + if self.autocommit: + raise RuntimeError("Can only commit if not already set to autocommit") + try: + shutil.move(self.temp, self.path) + except PermissionError as e: + # shutil.move raises PermissionError if os.rename + # and the default copy2 fallback with shutil.copystats fail. + # The file should be there nonetheless, but without copied permissions. + # If it doesn't exist, there was no permission to create the file. + if not os.path.exists(self.path): + raise e + else: + # If PermissionError is not raised, permissions can be set. + try: + mask = 0o666 + os.chmod(self.path, mask & ~get_umask(mask)) + except RuntimeError: + pass + + def discard(self): + if self.autocommit: + raise RuntimeError("Cannot discard if set to autocommit") + os.remove(self.temp) + + def readable(self) -> bool: + return True + + def writable(self) -> bool: + return "r" not in self.mode + + def read(self, *args, **kwargs): + return self.f.read(*args, **kwargs) + + def write(self, *args, **kwargs): + return self.f.write(*args, **kwargs) + + def tell(self, *args, **kwargs): + return self.f.tell(*args, **kwargs) + + def seek(self, *args, **kwargs): + return self.f.seek(*args, **kwargs) + + def seekable(self, *args, **kwargs): + return self.f.seekable(*args, **kwargs) + + def readline(self, *args, **kwargs): + return self.f.readline(*args, **kwargs) + + def readlines(self, *args, **kwargs): + return self.f.readlines(*args, **kwargs) + + def close(self): + return self.f.close() + + def truncate(self, size=None) -> int: + return self.f.truncate(size) + + @property + def closed(self): + return self.f.closed + + def fileno(self): + return self.raw.fileno() + + def flush(self) -> None: + self.f.flush() + + def __iter__(self): + return self.f.__iter__() + + def __getattr__(self, item): + return getattr(self.f, item) + + def __enter__(self): + self._incontext = True + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._incontext = False + self.f.__exit__(exc_type, exc_value, traceback) diff --git a/fsspec/implementations/memory.py b/fsspec/implementations/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..838876d1f87faa28ed05f02454e210fa534b736e --- /dev/null +++ b/fsspec/implementations/memory.py @@ -0,0 +1,311 @@ +from __future__ import annotations + +import logging +from datetime import datetime, timezone +from errno import ENOTEMPTY +from io import BytesIO +from pathlib import PurePath, PureWindowsPath +from typing import Any, ClassVar + +from fsspec import AbstractFileSystem +from fsspec.implementations.local import LocalFileSystem +from fsspec.utils import stringify_path + +logger = logging.getLogger("fsspec.memoryfs") + + +class MemoryFileSystem(AbstractFileSystem): + """A filesystem based on a dict of BytesIO objects + + This is a global filesystem so instances of this class all point to the same + in memory filesystem. + """ + + store: ClassVar[dict[str, Any]] = {} # global, do not overwrite! + pseudo_dirs = [""] # global, do not overwrite! + protocol = "memory" + root_marker = "/" + + @classmethod + def _strip_protocol(cls, path): + if isinstance(path, PurePath): + if isinstance(path, PureWindowsPath): + return LocalFileSystem._strip_protocol(path) + else: + path = stringify_path(path) + + path = path.removeprefix("memory://") + if "::" in path or "://" in path: + return path.rstrip("/") + path = path.lstrip("/").rstrip("/") + return "/" + path if path else "" + + def ls(self, path, detail=True, **kwargs): + path = self._strip_protocol(path) + if path in self.store: + # there is a key with this exact name + if not detail: + return [path] + return [ + { + "name": path, + "size": self.store[path].size, + "type": "file", + "created": self.store[path].created.timestamp(), + } + ] + paths = set() + starter = path + "/" + out = [] + for p2 in tuple(self.store): + if p2.startswith(starter): + if "/" not in p2[len(starter) :]: + # exact child + out.append( + { + "name": p2, + "size": self.store[p2].size, + "type": "file", + "created": self.store[p2].created.timestamp(), + } + ) + elif len(p2) > len(starter): + # implied child directory + ppath = starter + p2[len(starter) :].split("/", 1)[0] + if ppath not in paths: + out = out or [] + out.append( + { + "name": ppath, + "size": 0, + "type": "directory", + } + ) + paths.add(ppath) + for p2 in self.pseudo_dirs: + if p2.startswith(starter): + if "/" not in p2[len(starter) :]: + # exact child pdir + if p2 not in paths: + out.append({"name": p2, "size": 0, "type": "directory"}) + paths.add(p2) + else: + # directory implied by deeper pdir + ppath = starter + p2[len(starter) :].split("/", 1)[0] + if ppath not in paths: + out.append({"name": ppath, "size": 0, "type": "directory"}) + paths.add(ppath) + if not out: + if path in self.pseudo_dirs: + # empty dir + return [] + raise FileNotFoundError(path) + if detail: + return out + return sorted([f["name"] for f in out]) + + def mkdir(self, path, create_parents=True, **kwargs): + path = self._strip_protocol(path) + if path in self.store or path in self.pseudo_dirs: + raise FileExistsError(path) + if self._parent(path).strip("/") and self.isfile(self._parent(path)): + raise NotADirectoryError(self._parent(path)) + if create_parents and self._parent(path).strip("/"): + try: + self.mkdir(self._parent(path), create_parents, **kwargs) + except FileExistsError: + pass + if path and path not in self.pseudo_dirs: + self.pseudo_dirs.append(path) + + def makedirs(self, path, exist_ok=False): + try: + self.mkdir(path, create_parents=True) + except FileExistsError: + if not exist_ok: + raise + + def pipe_file(self, path, value, mode="overwrite", **kwargs): + """Set the bytes of given file + + Avoids copies of the data if possible + """ + mode = "xb" if mode == "create" else "wb" + self.open(path, mode=mode, data=value) + + def rmdir(self, path): + path = self._strip_protocol(path) + if path == "": + # silently avoid deleting FS root + return + if path in self.pseudo_dirs: + if not self.ls(path): + self.pseudo_dirs.remove(path) + else: + raise OSError(ENOTEMPTY, "Directory not empty", path) + else: + raise FileNotFoundError(path) + + def info(self, path, **kwargs): + logger.debug("info: %s", path) + path = self._strip_protocol(path) + if path in self.pseudo_dirs or any( + p.startswith(path + "/") for p in list(self.store) + self.pseudo_dirs + ): + return { + "name": path, + "size": 0, + "type": "directory", + } + elif path in self.store: + filelike = self.store[path] + return { + "name": path, + "size": filelike.size, + "type": "file", + "created": getattr(filelike, "created", None), + } + else: + raise FileNotFoundError(path) + + def _open( + self, + path, + mode="rb", + block_size=None, + autocommit=True, + cache_options=None, + **kwargs, + ): + path = self._strip_protocol(path) + if "x" in mode and self.exists(path): + raise FileExistsError + if path in self.pseudo_dirs: + raise IsADirectoryError(path) + parent = path + while len(parent) > 1: + parent = self._parent(parent) + if self.isfile(parent): + raise FileExistsError(parent) + if mode in ["rb", "ab", "r+b"]: + if path in self.store: + f = self.store[path] + if mode == "ab": + # position at the end of file + f.seek(0, 2) + else: + # position at the beginning of file + f.seek(0) + return f + else: + raise FileNotFoundError(path) + elif mode in {"wb", "xb"}: + if mode == "xb" and self.exists(path): + raise FileExistsError + m = MemoryFile(self, path, kwargs.get("data")) + if not self._intrans: + m.commit() + return m + else: + name = self.__class__.__name__ + raise ValueError(f"unsupported file mode for {name}: {mode!r}") + + def cp_file(self, path1, path2, **kwargs): + path1 = self._strip_protocol(path1) + path2 = self._strip_protocol(path2) + if self.isfile(path1): + self.store[path2] = MemoryFile( + self, path2, self.store[path1].getvalue() + ) # implicit copy + elif self.isdir(path1): + if path2 not in self.pseudo_dirs: + self.pseudo_dirs.append(path2) + else: + raise FileNotFoundError(path1) + + def cat_file(self, path, start=None, end=None, **kwargs): + logger.debug("cat: %s", path) + path = self._strip_protocol(path) + try: + return bytes(self.store[path].getbuffer()[start:end]) + except KeyError as e: + raise FileNotFoundError(path) from e + + def _rm(self, path): + path = self._strip_protocol(path) + try: + del self.store[path] + except KeyError as e: + raise FileNotFoundError(path) from e + + def modified(self, path): + path = self._strip_protocol(path) + try: + return self.store[path].modified + except KeyError as e: + raise FileNotFoundError(path) from e + + def created(self, path): + path = self._strip_protocol(path) + try: + return self.store[path].created + except KeyError as e: + raise FileNotFoundError(path) from e + + def isfile(self, path): + path = self._strip_protocol(path) + return path in self.store + + def rm(self, path, recursive=False, maxdepth=None): + if isinstance(path, str): + path = self._strip_protocol(path) + else: + path = [self._strip_protocol(p) for p in path] + paths = self.expand_path(path, recursive=recursive, maxdepth=maxdepth) + for p in reversed(paths): + if self.isfile(p): + self.rm_file(p) + # If the expanded path doesn't exist, it is only because the expanded + # path was a directory that does not exist in self.pseudo_dirs. This + # is possible if you directly create files without making the + # directories first. + elif not self.exists(p): + continue + else: + self.rmdir(p) + + +class MemoryFile(BytesIO): + """A BytesIO which can't close and works as a context manager + + Can initialise with data. Each path should only be active once at any moment. + + No need to provide fs, path if auto-committing (default) + """ + + def __init__(self, fs=None, path=None, data=None): + logger.debug("open file %s", path) + self.fs = fs + self.path = path + self.created = datetime.now(tz=timezone.utc) + self.modified = datetime.now(tz=timezone.utc) + if data: + super().__init__(data) + self.seek(0) + + @property + def size(self): + return self.getbuffer().nbytes + + def __enter__(self): + return self + + def close(self): + pass + + def discard(self): + pass + + def commit(self): + self.fs.store[self.path] = self + self.modified = datetime.now(tz=timezone.utc) diff --git a/fsspec/implementations/reference.py b/fsspec/implementations/reference.py new file mode 100644 index 0000000000000000000000000000000000000000..66be6ec7ca1cbebc5d80a328f7ed78becaeb0a37 --- /dev/null +++ b/fsspec/implementations/reference.py @@ -0,0 +1,1305 @@ +import base64 +import collections +import io +import itertools +import logging +import math +import os +from functools import lru_cache +from itertools import chain +from typing import TYPE_CHECKING, Literal + +import fsspec.core +from fsspec.spec import AbstractBufferedFile + +try: + import ujson as json +except ImportError: + if not TYPE_CHECKING: + import json + +from fsspec.asyn import AsyncFileSystem +from fsspec.callbacks import DEFAULT_CALLBACK +from fsspec.core import filesystem, open, split_protocol +from fsspec.implementations.asyn_wrapper import AsyncFileSystemWrapper +from fsspec.utils import isfilelike, merge_offset_ranges, other_paths + +logger = logging.getLogger("fsspec.reference") + + +class ReferenceNotReachable(RuntimeError): + def __init__(self, reference, target, *args): + super().__init__(*args) + self.reference = reference + self.target = target + + def __str__(self): + return f'Reference "{self.reference}" failed to fetch target {self.target}' + + +def _first(d): + return next(iter(d.values())) + + +def _prot_in_references(path, references): + ref = references.get(path) + if isinstance(ref, (list, tuple)) and isinstance(ref[0], str): + return split_protocol(ref[0])[0] if ref[0] else ref[0] + + +def _protocol_groups(paths, references): + if isinstance(paths, str): + return {_prot_in_references(paths, references): [paths]} + out = {} + for path in paths: + protocol = _prot_in_references(path, references) + out.setdefault(protocol, []).append(path) + return out + + +class RefsValuesView(collections.abc.ValuesView): + def __iter__(self): + for val in self._mapping.zmetadata.values(): + yield json.dumps(val).encode() + yield from self._mapping._items.values() + for field in self._mapping.listdir(): + chunk_sizes = self._mapping._get_chunk_sizes(field) + if len(chunk_sizes) == 0: + yield self._mapping[field + "/0"] + continue + yield from self._mapping._generate_all_records(field) + + +class RefsItemsView(collections.abc.ItemsView): + def __iter__(self): + return zip(self._mapping.keys(), self._mapping.values()) + + +def ravel_multi_index(idx, sizes): + val = 0 + mult = 1 + for i, s in zip(idx[::-1], sizes[::-1]): + val += i * mult + mult *= s + return val + + +class LazyReferenceMapper(collections.abc.MutableMapping): + """This interface can be used to read/write references from Parquet stores. + It is not intended for other types of references. + It can be used with Kerchunk's MultiZarrToZarr method to combine + references into a parquet store. + Examples of this use-case can be found here: + https://fsspec.github.io/kerchunk/advanced.html?highlight=parquet#parquet-storage""" + + # import is class level to prevent numpy dep requirement for fsspec + @property + def np(self): + import numpy as np + + return np + + @property + def pd(self): + import pandas as pd + + return pd + + def __init__( + self, + root, + fs=None, + out_root=None, + cache_size=128, + categorical_threshold=10, + engine: Literal["fastparquet", "pyarrow"] = "fastparquet", + ): + """ + + This instance will be writable, storing changes in memory until full partitions + are accumulated or .flush() is called. + + To create an empty lazy store, use .create() + + Parameters + ---------- + root : str + Root of parquet store + fs : fsspec.AbstractFileSystem + fsspec filesystem object, default is local filesystem. + cache_size : int, default=128 + Maximum size of LRU cache, where cache_size*record_size denotes + the total number of references that can be loaded in memory at once. + categorical_threshold : int + Encode urls as pandas.Categorical to reduce memory footprint if the ratio + of the number of unique urls to total number of refs for each variable + is greater than or equal to this number. (default 10) + engine: Literal["fastparquet","pyarrow"] + Engine choice for reading parquet files. (default is "fastparquet") + """ + + self.root = root + self.chunk_sizes = {} + self.cat_thresh = categorical_threshold + self.engine = engine + self.cache_size = cache_size + self.url = self.root + "/{field}/refs.{record}.parq" + # TODO: derive fs from `root` + self.fs = fsspec.filesystem("file") if fs is None else fs + self.out_root = self.fs.unstrip_protocol(out_root or self.root) + + from importlib.util import find_spec + + if self.engine == "pyarrow" and find_spec("pyarrow") is None: + raise ImportError("engine choice `pyarrow` is not installed.") + + def __getattr__(self, item): + if item in ("_items", "record_size", "zmetadata"): + self.setup() + # avoid possible recursion if setup fails somehow + return self.__dict__[item] + raise AttributeError(item) + + def setup(self): + self._items = {} + self._items[".zmetadata"] = self.fs.cat_file( + "/".join([self.root, ".zmetadata"]) + ) + met = json.loads(self._items[".zmetadata"]) + self.record_size = met["record_size"] + self.zmetadata = met["metadata"] + + # Define function to open and decompress refs + @lru_cache(maxsize=self.cache_size) + def open_refs(field, record): + """cached parquet file loader""" + path = self.url.format(field=field, record=record) + data = io.BytesIO(self.fs.cat_file(path)) + try: + df = self.pd.read_parquet(data, engine=self.engine) + refs = {c: df[c].to_numpy() for c in df.columns} + except OSError: + refs = None + return refs + + self.open_refs = open_refs + + @staticmethod + def create(root, storage_options=None, fs=None, record_size=10000, **kwargs): + """Make empty parquet reference set + + First deletes the contents of the given directory, if it exists. + + Parameters + ---------- + root: str + Directory to contain the output; will be created + storage_options: dict | None + For making the filesystem to use for writing is fs is None + fs: FileSystem | None + Filesystem for writing + record_size: int + Number of references per parquet file + kwargs: passed to __init__ + + Returns + ------- + LazyReferenceMapper instance + """ + met = {"metadata": {}, "record_size": record_size} + if fs is None: + fs, root = fsspec.core.url_to_fs(root, **(storage_options or {})) + if fs.exists(root): + fs.rm(root, recursive=True) + fs.makedirs(root, exist_ok=True) + fs.pipe("/".join([root, ".zmetadata"]), json.dumps(met).encode()) + return LazyReferenceMapper(root, fs, **kwargs) + + @lru_cache() + def listdir(self): + """List top-level directories""" + dirs = (p.rsplit("/", 1)[0] for p in self.zmetadata if not p.startswith(".z")) + return set(dirs) + + def ls(self, path="", detail=True): + """Shortcut file listings""" + path = path.rstrip("/") + pathdash = path + "/" if path else "" + dirnames = self.listdir() + dirs = [ + d + for d in dirnames + if d.startswith(pathdash) and "/" not in d.lstrip(pathdash) + ] + if dirs: + others = { + f + for f in chain( + [".zmetadata"], + (name for name in self.zmetadata), + (name for name in self._items), + ) + if f.startswith(pathdash) and "/" not in f.lstrip(pathdash) + } + if detail is False: + others.update(dirs) + return sorted(others) + dirinfo = [{"name": name, "type": "directory", "size": 0} for name in dirs] + fileinfo = [ + { + "name": name, + "type": "file", + "size": len( + json.dumps(self.zmetadata[name]) + if name in self.zmetadata + else self._items[name] + ), + } + for name in others + ] + return sorted(dirinfo + fileinfo, key=lambda s: s["name"]) + field = path + others = set( + [name for name in self.zmetadata if name.startswith(f"{path}/")] + + [name for name in self._items if name.startswith(f"{path}/")] + ) + fileinfo = [ + { + "name": name, + "type": "file", + "size": len( + json.dumps(self.zmetadata[name]) + if name in self.zmetadata + else self._items[name] + ), + } + for name in others + ] + keys = self._keys_in_field(field) + + if detail is False: + return list(others) + list(keys) + recs = self._generate_all_records(field) + recinfo = [ + {"name": name, "type": "file", "size": rec[-1]} + for name, rec in zip(keys, recs) + if rec[0] # filters out path==None, deleted/missing + ] + return fileinfo + recinfo + + def _load_one_key(self, key): + """Get the reference for one key + + Returns bytes, one-element list or three-element list. + """ + if key in self._items: + return self._items[key] + elif key in self.zmetadata: + return json.dumps(self.zmetadata[key]).encode() + elif "/" not in key or self._is_meta(key): + raise KeyError(key) + field, _ = key.rsplit("/", 1) + record, ri, chunk_size = self._key_to_record(key) + maybe = self._items.get((field, record), {}).get(ri, False) + if maybe is None: + # explicitly deleted + raise KeyError + elif maybe: + return maybe + elif chunk_size == 0: + return b"" + + # Chunk keys can be loaded from row group and cached in LRU cache + try: + refs = self.open_refs(field, record) + except (ValueError, TypeError, FileNotFoundError) as exc: + raise KeyError(key) from exc + columns = ["path", "offset", "size", "raw"] + selection = [refs[c][ri] if c in refs else None for c in columns] + raw = selection[-1] + if raw is not None: + return raw + if selection[0] is None: + raise KeyError("This reference does not exist or has been deleted") + if selection[1:3] == [0, 0]: + # URL only + return selection[:1] + # URL, offset, size + return selection[:3] + + @lru_cache(4096) + def _key_to_record(self, key): + """Details needed to construct a reference for one key""" + field, chunk = key.rsplit("/", 1) + chunk_sizes = self._get_chunk_sizes(field) + if len(chunk_sizes) == 0: + return 0, 0, 0 + chunk_idx = [int(c) for c in chunk.split(".")] + chunk_number = ravel_multi_index(chunk_idx, chunk_sizes) + record = chunk_number // self.record_size + ri = chunk_number % self.record_size + return record, ri, len(chunk_sizes) + + def _get_chunk_sizes(self, field): + """The number of chunks along each axis for a given field""" + if field not in self.chunk_sizes: + zarray = self.zmetadata[f"{field}/.zarray"] + size_ratio = [ + math.ceil(s / c) for s, c in zip(zarray["shape"], zarray["chunks"]) + ] + self.chunk_sizes[field] = size_ratio or [1] + return self.chunk_sizes[field] + + def _generate_record(self, field, record): + """The references for a given parquet file of a given field""" + refs = self.open_refs(field, record) + it = iter(zip(*refs.values())) + if len(refs) == 3: + # All urls + return (list(t) for t in it) + elif len(refs) == 1: + # All raws + return refs["raw"] + else: + # Mix of urls and raws + return (list(t[:3]) if not t[3] else t[3] for t in it) + + def _generate_all_records(self, field): + """Load all the references within a field by iterating over the parquet files""" + nrec = 1 + for ch in self._get_chunk_sizes(field): + nrec *= ch + nrec = math.ceil(nrec / self.record_size) + for record in range(nrec): + yield from self._generate_record(field, record) + + def values(self): + return RefsValuesView(self) + + def items(self): + return RefsItemsView(self) + + def __hash__(self): + return id(self) + + def __getitem__(self, key): + return self._load_one_key(key) + + def __setitem__(self, key, value): + if "/" in key and not self._is_meta(key): + field, chunk = key.rsplit("/", 1) + record, i, _ = self._key_to_record(key) + subdict = self._items.setdefault((field, record), {}) + subdict[i] = value + if len(subdict) == self.record_size: + self.write(field, record) + else: + # metadata or top-level + if hasattr(value, "to_bytes"): + val = value.to_bytes().decode() + elif isinstance(value, bytes): + val = value.decode() + else: + val = value + self._items[key] = val + new_value = json.loads(val) + self.zmetadata[key] = {**self.zmetadata.get(key, {}), **new_value} + + @staticmethod + def _is_meta(key): + return key.startswith(".z") or "/.z" in key + + def __delitem__(self, key): + if key in self._items: + del self._items[key] + elif key in self.zmetadata: + del self.zmetadata[key] + else: + if "/" in key and not self._is_meta(key): + field, _ = key.rsplit("/", 1) + record, i, _ = self._key_to_record(key) + subdict = self._items.setdefault((field, record), {}) + subdict[i] = None + if len(subdict) == self.record_size: + self.write(field, record) + else: + # metadata or top-level + self._items[key] = None + + def write(self, field, record, base_url=None, storage_options=None): + # extra requirements if writing + import kerchunk.df + import numpy as np + import pandas as pd + + partition = self._items[(field, record)] + original = False + if len(partition) < self.record_size: + try: + original = self.open_refs(field, record) + except OSError: + pass + + if original: + paths = original["path"] + offsets = original["offset"] + sizes = original["size"] + raws = original["raw"] + else: + paths = np.full(self.record_size, np.nan, dtype="O") + offsets = np.zeros(self.record_size, dtype="int64") + sizes = np.zeros(self.record_size, dtype="int64") + raws = np.full(self.record_size, np.nan, dtype="O") + for j, data in partition.items(): + if isinstance(data, list): + if ( + str(paths.dtype) == "category" + and data[0] not in paths.dtype.categories + ): + paths = paths.add_categories(data[0]) + paths[j] = data[0] + if len(data) > 1: + offsets[j] = data[1] + sizes[j] = data[2] + elif data is None: + # delete + paths[j] = None + offsets[j] = 0 + sizes[j] = 0 + raws[j] = None + else: + # this is the only call into kerchunk, could remove + raws[j] = kerchunk.df._proc_raw(data) + # TODO: only save needed columns + df = pd.DataFrame( + { + "path": paths, + "offset": offsets, + "size": sizes, + "raw": raws, + }, + copy=False, + ) + if df.path.count() / (df.path.nunique() or 1) > self.cat_thresh: + df["path"] = df["path"].astype("category") + object_encoding = {"raw": "bytes", "path": "utf8"} + has_nulls = ["path", "raw"] + + fn = f"{base_url or self.out_root}/{field}/refs.{record}.parq" + self.fs.mkdirs(f"{base_url or self.out_root}/{field}", exist_ok=True) + + if self.engine == "pyarrow": + df_backend_kwargs = {"write_statistics": False} + elif self.engine == "fastparquet": + df_backend_kwargs = { + "stats": False, + "object_encoding": object_encoding, + "has_nulls": has_nulls, + } + else: + raise NotImplementedError(f"{self.engine} not supported") + df.to_parquet( + fn, + engine=self.engine, + storage_options=storage_options + or getattr(self.fs, "storage_options", None), + compression="zstd", + index=False, + **df_backend_kwargs, + ) + + partition.clear() + self._items.pop((field, record)) + + def flush(self, base_url=None, storage_options=None): + """Output any modified or deleted keys + + Parameters + ---------- + base_url: str + Location of the output + """ + + # write what we have so far and clear sub chunks + for thing in list(self._items): + if isinstance(thing, tuple): + field, record = thing + self.write( + field, + record, + base_url=base_url, + storage_options=storage_options, + ) + + # gather .zmetadata from self._items and write that too + for k in list(self._items): + if k != ".zmetadata" and ".z" in k: + self.zmetadata[k] = json.loads(self._items.pop(k)) + met = {"metadata": self.zmetadata, "record_size": self.record_size} + self._items.clear() + self._items[".zmetadata"] = json.dumps(met).encode() + self.fs.pipe( + "/".join([base_url or self.out_root, ".zmetadata"]), + self._items[".zmetadata"], + ) + + # TODO: only clear those that we wrote to? + self.open_refs.cache_clear() + + def __len__(self): + # Caveat: This counts expected references, not actual - but is fast + count = 0 + for field in self.listdir(): + if field.startswith("."): + count += 1 + else: + count += math.prod(self._get_chunk_sizes(field)) + count += len(self.zmetadata) # all metadata keys + # any other files not in reference partitions + count += sum(1 for _ in self._items if not isinstance(_, tuple)) + return count + + def __iter__(self): + # Caveat: returns only existing keys, so the number of these does not + # match len(self) + metas = set(self.zmetadata) + metas.update(self._items) + for bit in metas: + if isinstance(bit, str): + yield bit + for field in self.listdir(): + for k in self._keys_in_field(field): + if k in self: + yield k + + def __contains__(self, item): + try: + self._load_one_key(item) + return True + except KeyError: + return False + + def _keys_in_field(self, field): + """List key names in given field + + Produces strings like "field/x.y" appropriate from the chunking of the array + """ + chunk_sizes = self._get_chunk_sizes(field) + if len(chunk_sizes) == 0: + yield field + "/0" + return + inds = itertools.product(*(range(i) for i in chunk_sizes)) + for ind in inds: + yield field + "/" + ".".join([str(c) for c in ind]) + + +class ReferenceFileSystem(AsyncFileSystem): + """View byte ranges of some other file as a file system + Initial version: single file system target, which must support + async, and must allow start and end args in _cat_file. Later versions + may allow multiple arbitrary URLs for the targets. + This FileSystem is read-only. It is designed to be used with async + targets (for now). We do not get original file details from the target FS. + Configuration is by passing a dict of references at init, or a URL to + a JSON file containing the same; this dict + can also contain concrete data for some set of paths. + Reference dict format: + {path0: bytes_data, path1: (target_url, offset, size)} + https://github.com/fsspec/kerchunk/blob/main/README.md + """ + + protocol = "reference" + cachable = False + + def __init__( + self, + fo, + target=None, + ref_storage_args=None, + target_protocol=None, + target_options=None, + remote_protocol=None, + remote_options=None, + fs=None, + template_overrides=None, + simple_templates=True, + max_gap=64_000, + max_block=256_000_000, + cache_size=128, + **kwargs, + ): + """ + Parameters + ---------- + fo : dict or str + The set of references to use for this instance, with a structure as above. + If str referencing a JSON file, will use fsspec.open, in conjunction + with target_options and target_protocol to open and parse JSON at this + location. If a directory, then assume references are a set of parquet + files to be loaded lazily. + target : str + For any references having target_url as None, this is the default file + target to use + ref_storage_args : dict + If references is a str, use these kwargs for loading the JSON file. + Deprecated: use target_options instead. + target_protocol : str + Used for loading the reference file, if it is a path. If None, protocol + will be derived from the given path + target_options : dict + Extra FS options for loading the reference file ``fo``, if given as a path + remote_protocol : str + The protocol of the filesystem on which the references will be evaluated + (unless fs is provided). If not given, will be derived from the first + URL that has a protocol in the templates or in the references, in that + order. + remote_options : dict + kwargs to go with remote_protocol + fs : AbstractFileSystem | dict(str, (AbstractFileSystem | dict)) + Directly provide a file system(s): + - a single filesystem instance + - a dict of protocol:filesystem, where each value is either a filesystem + instance, or a dict of kwargs that can be used to create in + instance for the given protocol + + If this is given, remote_options and remote_protocol are ignored. + template_overrides : dict + Swap out any templates in the references file with these - useful for + testing. + simple_templates: bool + Whether templates can be processed with simple replace (True) or if + jinja is needed (False, much slower). All reference sets produced by + ``kerchunk`` are simple in this sense, but the spec allows for complex. + max_gap, max_block: int + For merging multiple concurrent requests to the same remote file. + Neighboring byte ranges will only be merged when their + inter-range gap is <= ``max_gap``. Default is 64KB. Set to 0 + to only merge when it requires no extra bytes. Pass a negative + number to disable merging, appropriate for local target files. + Neighboring byte ranges will only be merged when the size of + the aggregated range is <= ``max_block``. Default is 256MB. + cache_size : int + Maximum size of LRU cache, where cache_size*record_size denotes + the total number of references that can be loaded in memory at once. + Only used for lazily loaded references. + kwargs : passed to parent class + """ + super().__init__(**kwargs) + self.target = target + self.template_overrides = template_overrides + self.simple_templates = simple_templates + self.templates = {} + self.fss = {} + self._dircache = {} + self.max_gap = max_gap + self.max_block = max_block + if isinstance(fo, str): + dic = dict( + **(ref_storage_args or target_options or {}), protocol=target_protocol + ) + ref_fs, fo2 = fsspec.core.url_to_fs(fo, **dic) + if ref_fs.isfile(fo2): + # text JSON + with fsspec.open(fo, "rb", **dic) as f: + logger.info("Read reference from URL %s", fo) + text = json.load(f) + self._process_references(text, template_overrides) + else: + # Lazy parquet refs + logger.info("Open lazy reference dict from URL %s", fo) + self.references = LazyReferenceMapper( + fo2, + fs=ref_fs, + cache_size=cache_size, + ) + else: + # dictionaries + self._process_references(fo, template_overrides) + if isinstance(fs, dict): + self.fss = { + k: ( + fsspec.filesystem(k.split(":", 1)[0], **opts) + if isinstance(opts, dict) + else opts + ) + for k, opts in fs.items() + } + if None not in self.fss: + self.fss[None] = filesystem("file") + return + if fs is not None: + # single remote FS + remote_protocol = ( + fs.protocol[0] if isinstance(fs.protocol, tuple) else fs.protocol + ) + self.fss[remote_protocol] = fs + + if remote_protocol is None: + # get single protocol from any templates + for ref in self.templates.values(): + if callable(ref): + ref = ref() + protocol, _ = fsspec.core.split_protocol(ref) + if protocol and protocol not in self.fss: + fs = filesystem(protocol, **(remote_options or {})) + self.fss[protocol] = fs + if remote_protocol is None: + # get single protocol from references + # TODO: warning here, since this can be very expensive? + for ref in self.references.values(): + if callable(ref): + ref = ref() + if isinstance(ref, list) and ref[0]: + protocol, _ = fsspec.core.split_protocol(ref[0]) + if protocol not in self.fss: + fs = filesystem(protocol, **(remote_options or {})) + self.fss[protocol] = fs + # only use first remote URL + break + + if remote_protocol and remote_protocol not in self.fss: + fs = filesystem(remote_protocol, **(remote_options or {})) + self.fss[remote_protocol] = fs + + self.fss[None] = fs or filesystem("file") # default one + # Wrap any non-async filesystems to ensure async methods are available below + for k, f in self.fss.items(): + if not f.async_impl: + self.fss[k] = AsyncFileSystemWrapper(f, asynchronous=self.asynchronous) + elif self.asynchronous ^ f.asynchronous: + raise ValueError( + "Reference-FS's target filesystem must have same value " + "of asynchronous" + ) + + def _cat_common(self, path, start=None, end=None): + path = self._strip_protocol(path) + logger.debug(f"cat: {path}") + try: + part = self.references[path] + except KeyError as exc: + raise FileNotFoundError(path) from exc + if isinstance(part, str): + part = part.encode() + if hasattr(part, "to_bytes"): + part = part.to_bytes() + if isinstance(part, bytes): + logger.debug(f"Reference: {path}, type bytes") + if part.startswith(b"base64:"): + part = base64.b64decode(part[7:]) + return part, None, None + + if len(part) == 1: + logger.debug(f"Reference: {path}, whole file => {part}") + url = part[0] + start1, end1 = start, end + else: + url, start0, size = part + logger.debug(f"Reference: {path} => {url}, offset {start0}, size {size}") + end0 = start0 + size + + if start is not None: + if start >= 0: + start1 = start0 + start + else: + start1 = end0 + start + else: + start1 = start0 + if end is not None: + if end >= 0: + end1 = start0 + end + else: + end1 = end0 + end + else: + end1 = end0 + if url is None: + url = self.target + return url, start1, end1 + + async def _cat_file(self, path, start=None, end=None, **kwargs): + part_or_url, start0, end0 = self._cat_common(path, start=start, end=end) + if isinstance(part_or_url, bytes): + return part_or_url[start:end] + protocol, _ = split_protocol(part_or_url) + try: + return await self.fss[protocol]._cat_file( + part_or_url, start=start0, end=end0 + ) + except Exception as e: + raise ReferenceNotReachable(path, part_or_url) from e + + def cat_file(self, path, start=None, end=None, **kwargs): + part_or_url, start0, end0 = self._cat_common(path, start=start, end=end) + if isinstance(part_or_url, bytes): + return part_or_url[start:end] + protocol, _ = split_protocol(part_or_url) + try: + return self.fss[protocol].cat_file(part_or_url, start=start0, end=end0) + except Exception as e: + raise ReferenceNotReachable(path, part_or_url) from e + + def pipe_file(self, path, value, **_): + """Temporarily add binary data or reference as a file""" + self.references[path] = value + + async def _get_file(self, rpath, lpath, **kwargs): + if self.isdir(rpath): + return os.makedirs(lpath, exist_ok=True) + data = await self._cat_file(rpath) + with open(lpath, "wb") as f: + f.write(data) + + def get_file(self, rpath, lpath, callback=DEFAULT_CALLBACK, **kwargs): + if self.isdir(rpath): + return os.makedirs(lpath, exist_ok=True) + data = self.cat_file(rpath, **kwargs) + callback.set_size(len(data)) + if isfilelike(lpath): + lpath.write(data) + else: + with open(lpath, "wb") as f: + f.write(data) + callback.absolute_update(len(data)) + + def get(self, rpath, lpath, recursive=False, **kwargs): + if recursive: + # trigger directory build + self.ls("") + rpath = self.expand_path(rpath, recursive=recursive) + fs = fsspec.filesystem("file", auto_mkdir=True) + targets = other_paths(rpath, lpath) + if recursive: + data = self.cat([r for r in rpath if not self.isdir(r)]) + else: + data = self.cat(rpath) + for remote, local in zip(rpath, targets): + if remote in data: + fs.pipe_file(local, data[remote]) + + def cat(self, path, recursive=False, on_error="raise", **kwargs): + if isinstance(path, str) and recursive: + raise NotImplementedError + if isinstance(path, list) and (recursive or any("*" in p for p in path)): + raise NotImplementedError + # TODO: if references is lazy, pre-fetch all paths in batch before access + proto_dict = _protocol_groups(path, self.references) + out = {} + for proto, paths in proto_dict.items(): + fs = self.fss[proto] + urls, starts, ends, valid_paths = [], [], [], [] + for p in paths: + # find references or label not-found. Early exit if any not + # found and on_error is "raise" + try: + u, s, e = self._cat_common(p) + if not isinstance(u, (bytes, str)): + # nan/None from parquet + continue + except FileNotFoundError as err: + if on_error == "raise": + raise + if on_error != "omit": + out[p] = err + else: + urls.append(u) + starts.append(s) + ends.append(e) + valid_paths.append(p) + + # process references into form for merging + urls2 = [] + starts2 = [] + ends2 = [] + paths2 = [] + whole_files = set() + for u, s, e, p in zip(urls, starts, ends, valid_paths): + if isinstance(u, bytes): + # data + out[p] = u + elif s is None: + # whole file - limits are None, None, but no further + # entries take for this file + whole_files.add(u) + urls2.append(u) + starts2.append(s) + ends2.append(e) + paths2.append(p) + for u, s, e, p in zip(urls, starts, ends, valid_paths): + # second run to account for files that are to be loaded whole + if s is not None and u not in whole_files: + urls2.append(u) + starts2.append(s) + ends2.append(e) + paths2.append(p) + + # merge and fetch consolidated ranges + new_paths, new_starts, new_ends = merge_offset_ranges( + list(urls2), + list(starts2), + list(ends2), + sort=True, + max_gap=self.max_gap, + max_block=self.max_block, + ) + bytes_out = fs.cat_ranges(new_paths, new_starts, new_ends) + + # unbundle from merged bytes - simple approach + for u, s, e, p in zip(urls, starts, ends, valid_paths): + if p in out: + continue # was bytes, already handled + for np, ns, ne, b in zip(new_paths, new_starts, new_ends, bytes_out): + if np == u and (ns is None or ne is None): + if isinstance(b, Exception): + out[p] = b + else: + out[p] = b[s:e] + elif np == u and s >= ns and e <= ne: + if isinstance(b, Exception): + out[p] = b + else: + out[p] = b[s - ns : (e - ne) or None] + + for k, v in out.copy().items(): + # these were valid references, but fetch failed, so transform exc + if isinstance(v, Exception) and k in self.references: + ex = out[k] + new_ex = ReferenceNotReachable(k, self.references[k]) + new_ex.__cause__ = ex + if on_error == "raise": + raise new_ex + elif on_error != "omit": + out[k] = new_ex + + if len(out) == 1 and isinstance(path, str) and "*" not in path: + return _first(out) + return out + + def _process_references(self, references, template_overrides=None): + vers = references.get("version", None) + if vers is None: + self._process_references0(references) + elif vers == 1: + self._process_references1(references, template_overrides=template_overrides) + else: + raise ValueError(f"Unknown reference spec version: {vers}") + # TODO: we make dircache by iterating over all entries, but for Spec >= 1, + # can replace with programmatic. Is it even needed for mapper interface? + + def _process_references0(self, references): + """Make reference dict for Spec Version 0""" + if isinstance(references, dict): + # do not do this for lazy/parquet backend, which will not make dicts, + # but must remain writable in the original object + references = { + key: json.dumps(val) if isinstance(val, dict) else val + for key, val in references.items() + } + self.references = references + + def _process_references1(self, references, template_overrides=None): + if not self.simple_templates or self.templates: + import jinja2 + self.references = {} + self._process_templates(references.get("templates", {})) + + @lru_cache(1000) + def _render_jinja(u): + return jinja2.Template(u).render(**self.templates) + + for k, v in references.get("refs", {}).items(): + if isinstance(v, str): + if v.startswith("base64:"): + self.references[k] = base64.b64decode(v[7:]) + self.references[k] = v + elif isinstance(v, dict): + self.references[k] = json.dumps(v) + elif self.templates: + u = v[0] + if "{{" in u: + if self.simple_templates: + u = ( + u.replace("{{", "{") + .replace("}}", "}") + .format(**self.templates) + ) + else: + u = _render_jinja(u) + self.references[k] = [u] if len(v) == 1 else [u, v[1], v[2]] + else: + self.references[k] = v + self.references.update(self._process_gen(references.get("gen", []))) + + def _process_templates(self, tmp): + self.templates = {} + if self.template_overrides is not None: + tmp.update(self.template_overrides) + for k, v in tmp.items(): + if "{{" in v: + import jinja2 + + self.templates[k] = lambda temp=v, **kwargs: jinja2.Template( + temp + ).render(**kwargs) + else: + self.templates[k] = v + + def _process_gen(self, gens): + out = {} + for gen in gens: + dimension = { + k: ( + v + if isinstance(v, list) + else range(v.get("start", 0), v["stop"], v.get("step", 1)) + ) + for k, v in gen["dimensions"].items() + } + products = ( + dict(zip(dimension.keys(), values)) + for values in itertools.product(*dimension.values()) + ) + for pr in products: + import jinja2 + + key = jinja2.Template(gen["key"]).render(**pr, **self.templates) + url = jinja2.Template(gen["url"]).render(**pr, **self.templates) + if ("offset" in gen) and ("length" in gen): + offset = int( + jinja2.Template(gen["offset"]).render(**pr, **self.templates) + ) + length = int( + jinja2.Template(gen["length"]).render(**pr, **self.templates) + ) + out[key] = [url, offset, length] + elif ("offset" in gen) ^ ("length" in gen): + raise ValueError( + "Both 'offset' and 'length' are required for a " + "reference generator entry if either is provided." + ) + else: + out[key] = [url] + return out + + def _dircache_from_items(self): + self.dircache = {"": []} + it = self.references.items() + for path, part in it: + if isinstance(part, (bytes, str)) or hasattr(part, "to_bytes"): + size = len(part) + elif len(part) == 1: + size = None + else: + _, _, size = part + par = path.rsplit("/", 1)[0] if "/" in path else "" + par0 = par + subdirs = [par0] + while par0 and par0 not in self.dircache: + # collect parent directories + par0 = self._parent(par0) + subdirs.append(par0) + + subdirs.reverse() + for parent, child in zip(subdirs, subdirs[1:]): + # register newly discovered directories + assert child not in self.dircache + assert parent in self.dircache + self.dircache[parent].append( + {"name": child, "type": "directory", "size": 0} + ) + self.dircache[child] = [] + + self.dircache[par].append({"name": path, "type": "file", "size": size}) + + def _open(self, path, mode="rb", block_size=None, cache_options=None, **kwargs): + part_or_url, start0, end0 = self._cat_common(path) + # This logic is kept outside `ReferenceFile` to avoid unnecessary redirection. + # That does mean `_cat_common` gets called twice if it eventually reaches `ReferenceFile`. + if isinstance(part_or_url, bytes): + return io.BytesIO(part_or_url[start0:end0]) + + protocol, _ = split_protocol(part_or_url) + if start0 is None and end0 is None: + return self.fss[protocol]._open( + part_or_url, + mode, + block_size=block_size, + cache_options=cache_options, + **kwargs, + ) + + return ReferenceFile( + self, + path, + mode, + block_size=block_size, + cache_options=cache_options, + **kwargs, + ) + + def ls(self, path, detail=True, **kwargs): + logger.debug("list %s", path) + path = self._strip_protocol(path) + if isinstance(self.references, LazyReferenceMapper): + try: + return self.references.ls(path, detail) + except KeyError: + pass + raise FileNotFoundError(f"'{path}' is not a known key") + if not self.dircache: + self._dircache_from_items() + out = self._ls_from_cache(path) + if out is None: + raise FileNotFoundError(path) + if detail: + return out + return [o["name"] for o in out] + + def exists(self, path, **kwargs): # overwrite auto-sync version + return self.isdir(path) or self.isfile(path) + + def isdir(self, path): # overwrite auto-sync version + if self.dircache: + return path in self.dircache + elif isinstance(self.references, LazyReferenceMapper): + return path in self.references.listdir() + else: + # this may be faster than building dircache for single calls, but + # by looping will be slow for many calls; could cache it? + return any(_.startswith(f"{path}/") for _ in self.references) + + def isfile(self, path): # overwrite auto-sync version + return path in self.references + + async def _ls(self, path, detail=True, **kwargs): # calls fast sync code + return self.ls(path, detail, **kwargs) + + def find(self, path, maxdepth=None, withdirs=False, detail=False, **kwargs): + if withdirs: + return super().find( + path, maxdepth=maxdepth, withdirs=withdirs, detail=detail, **kwargs + ) + if path: + path = self._strip_protocol(path) + r = sorted(k for k in self.references if k.startswith(path)) + else: + r = sorted(self.references) + if detail: + if not self.dircache: + self._dircache_from_items() + return {k: self._ls_from_cache(k)[0] for k in r} + else: + return r + + def info(self, path, **kwargs): + out = self.references.get(path) + if out is not None: + if isinstance(out, (str, bytes)): + # decode base64 here + return {"name": path, "type": "file", "size": len(out)} + elif len(out) > 1: + return {"name": path, "type": "file", "size": out[2]} + else: + out0 = [{"name": path, "type": "file", "size": None}] + else: + out = self.ls(path, True) + out0 = [o for o in out if o["name"] == path] + if not out0: + return {"name": path, "type": "directory", "size": 0} + if out0[0]["size"] is None: + # if this is a whole remote file, update size using remote FS + prot, _ = split_protocol(self.references[path][0]) + out0[0]["size"] = self.fss[prot].size(self.references[path][0]) + return out0[0] + + async def _info(self, path, **kwargs): # calls fast sync code + return self.info(path) + + async def _rm_file(self, path, **kwargs): + self.references.pop( + path, None + ) # ignores FileNotFound, just as well for directories + self.dircache.clear() # this is a bit heavy handed + + async def _pipe_file(self, path, data, mode="overwrite", **kwargs): + if mode == "create" and self.exists(path): + raise FileExistsError + # can be str or bytes + self.references[path] = data + self.dircache.clear() # this is a bit heavy handed + + async def _put_file(self, lpath, rpath, mode="overwrite", **kwargs): + # puts binary + if mode == "create" and self.exists(rpath): + raise FileExistsError + with open(lpath, "rb") as f: + self.references[rpath] = f.read() + self.dircache.clear() # this is a bit heavy handed + + def save_json(self, url, **storage_options): + """Write modified references into new location""" + out = {} + for k, v in self.references.items(): + if isinstance(v, bytes): + try: + out[k] = v.decode("ascii") + except UnicodeDecodeError: + out[k] = (b"base64:" + base64.b64encode(v)).decode() + else: + out[k] = v + with fsspec.open(url, "wb", **storage_options) as f: + f.write(json.dumps({"version": 1, "refs": out}).encode()) + + +class ReferenceFile(AbstractBufferedFile): + def __init__( + self, + fs, + path, + mode="rb", + block_size="default", + autocommit=True, + cache_type="readahead", + cache_options=None, + size=None, + **kwargs, + ): + super().__init__( + fs, + path, + mode=mode, + block_size=block_size, + autocommit=autocommit, + size=size, + cache_type=cache_type, + cache_options=cache_options, + **kwargs, + ) + part_or_url, self.start, self.end = self.fs._cat_common(self.path) + protocol, _ = split_protocol(part_or_url) + self.src_fs = self.fs.fss[protocol] + self.src_path = part_or_url + self._f = None + + @property + def f(self): + if self._f is None or self._f.closed: + self._f = self.src_fs._open( + self.src_path, + mode=self.mode, + block_size=self.blocksize, + autocommit=self.autocommit, + cache_type="none", + **self.kwargs, + ) + return self._f + + def close(self): + if self._f is not None: + self._f.close() + return super().close() + + def _fetch_range(self, start, end): + start = start + self.start + end = min(end + self.start, self.end) + self.f.seek(start) + return self.f.read(end - start) diff --git a/fsspec/implementations/sftp.py b/fsspec/implementations/sftp.py new file mode 100644 index 0000000000000000000000000000000000000000..77f7b370cd246f9a9bfd34141afc3edd728d13c3 --- /dev/null +++ b/fsspec/implementations/sftp.py @@ -0,0 +1,180 @@ +import datetime +import logging +import os +import types +import uuid +from stat import S_ISDIR, S_ISLNK + +import paramiko + +from .. import AbstractFileSystem +from ..utils import infer_storage_options + +logger = logging.getLogger("fsspec.sftp") + + +class SFTPFileSystem(AbstractFileSystem): + """Files over SFTP/SSH + + Peer-to-peer filesystem over SSH using paramiko. + + Note: if using this with the ``open`` or ``open_files``, with full URLs, + there is no way to tell if a path is relative, so all paths are assumed + to be absolute. + """ + + protocol = "sftp", "ssh" + + def __init__(self, host, **ssh_kwargs): + """ + + Parameters + ---------- + host: str + Hostname or IP as a string + temppath: str + Location on the server to put files, when within a transaction + ssh_kwargs: dict + Parameters passed on to connection. See details in + https://docs.paramiko.org/en/3.3/api/client.html#paramiko.client.SSHClient.connect + May include port, username, password... + """ + if self._cached: + return + super().__init__(**ssh_kwargs) + self.temppath = ssh_kwargs.pop("temppath", "/tmp") # remote temp directory + self.host = host + self.ssh_kwargs = ssh_kwargs + self._connect() + + def _connect(self): + logger.debug("Connecting to SFTP server %s", self.host) + self.client = paramiko.SSHClient() + self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + self.client.connect(self.host, **self.ssh_kwargs) + self.ftp = self.client.open_sftp() + + @classmethod + def _strip_protocol(cls, path): + return infer_storage_options(path)["path"] + + @staticmethod + def _get_kwargs_from_urls(urlpath): + out = infer_storage_options(urlpath) + out.pop("path", None) + out.pop("protocol", None) + return out + + def mkdir(self, path, create_parents=True, mode=511): + logger.debug("Creating folder %s", path) + if self.exists(path): + raise FileExistsError(f"File exists: {path}") + + if create_parents: + self.makedirs(path) + else: + self.ftp.mkdir(path, mode) + + def makedirs(self, path, exist_ok=False, mode=511): + if self.exists(path) and not exist_ok: + raise FileExistsError(f"File exists: {path}") + + parts = path.split("/") + new_path = "/" if path[:1] == "/" else "" + + for part in parts: + if part: + new_path = f"{new_path}/{part}" if new_path else part + if not self.exists(new_path): + self.ftp.mkdir(new_path, mode) + + def rmdir(self, path): + logger.debug("Removing folder %s", path) + self.ftp.rmdir(path) + + def info(self, path): + stat = self._decode_stat(self.ftp.stat(path)) + stat["name"] = path + return stat + + @staticmethod + def _decode_stat(stat, parent_path=None): + if S_ISDIR(stat.st_mode): + t = "directory" + elif S_ISLNK(stat.st_mode): + t = "link" + else: + t = "file" + out = { + "name": "", + "size": stat.st_size, + "type": t, + "uid": stat.st_uid, + "gid": stat.st_gid, + "time": datetime.datetime.fromtimestamp( + stat.st_atime, tz=datetime.timezone.utc + ), + "mtime": datetime.datetime.fromtimestamp( + stat.st_mtime, tz=datetime.timezone.utc + ), + } + if parent_path: + out["name"] = "/".join([parent_path.rstrip("/"), stat.filename]) + return out + + def ls(self, path, detail=False): + logger.debug("Listing folder %s", path) + stats = [self._decode_stat(stat, path) for stat in self.ftp.listdir_iter(path)] + if detail: + return stats + else: + paths = [stat["name"] for stat in stats] + return sorted(paths) + + def put(self, lpath, rpath, callback=None, **kwargs): + logger.debug("Put file %s into %s", lpath, rpath) + self.ftp.put(lpath, rpath) + + def get_file(self, rpath, lpath, **kwargs): + if self.isdir(rpath): + os.makedirs(lpath, exist_ok=True) + else: + self.ftp.get(self._strip_protocol(rpath), lpath) + + def _open(self, path, mode="rb", block_size=None, **kwargs): + """ + block_size: int or None + If 0, no buffering, if 1, line buffering, if >1, buffer that many + bytes, if None use default from paramiko. + """ + logger.debug("Opening file %s", path) + if kwargs.get("autocommit", True) is False: + # writes to temporary file, move on commit + path2 = "/".join([self.temppath, str(uuid.uuid4())]) + f = self.ftp.open(path2, mode, bufsize=block_size if block_size else -1) + f.temppath = path2 + f.targetpath = path + f.fs = self + f.commit = types.MethodType(commit_a_file, f) + f.discard = types.MethodType(discard_a_file, f) + else: + f = self.ftp.open(path, mode, bufsize=block_size if block_size else -1) + return f + + def _rm(self, path): + if self.isdir(path): + self.ftp.rmdir(path) + else: + self.ftp.remove(path) + + def mv(self, old, new): + logger.debug("Renaming %s into %s", old, new) + self.ftp.posix_rename(old, new) + + +def commit_a_file(self): + self.fs.mv(self.temppath, self.targetpath) + + +def discard_a_file(self): + self.fs._rm(self.temppath) diff --git a/fsspec/implementations/smb.py b/fsspec/implementations/smb.py new file mode 100644 index 0000000000000000000000000000000000000000..db6b3f5c3702de90cf121ccca49f3ca2b580df9f --- /dev/null +++ b/fsspec/implementations/smb.py @@ -0,0 +1,416 @@ +""" +This module contains SMBFileSystem class responsible for handling access to +Windows Samba network shares by using package smbprotocol +""" + +import datetime +import re +import uuid +from stat import S_ISDIR, S_ISLNK + +import smbclient +import smbprotocol.exceptions + +from .. import AbstractFileSystem +from ..utils import infer_storage_options + +# ! pylint: disable=bad-continuation + + +class SMBFileSystem(AbstractFileSystem): + """Allow reading and writing to Windows and Samba network shares. + + When using `fsspec.open()` for getting a file-like object the URI + should be specified as this format: + ``smb://workgroup;user:password@server:port/share/folder/file.csv``. + + Example:: + + >>> import fsspec + >>> with fsspec.open( + ... 'smb://myuser:mypassword@myserver.com/' 'share/folder/file.csv' + ... ) as smbfile: + ... df = pd.read_csv(smbfile, sep='|', header=None) + + Note that you need to pass in a valid hostname or IP address for the host + component of the URL. Do not use the Windows/NetBIOS machine name for the + host component. + + The first component of the path in the URL points to the name of the shared + folder. Subsequent path components will point to the directory/folder/file. + + The URL components ``workgroup`` , ``user``, ``password`` and ``port`` may be + optional. + + .. note:: + + For working this source require `smbprotocol`_ to be installed, e.g.:: + + $ pip install smbprotocol + # or + # pip install smbprotocol[kerberos] + + .. _smbprotocol: https://github.com/jborean93/smbprotocol#requirements + + Note: if using this with the ``open`` or ``open_files``, with full URLs, + there is no way to tell if a path is relative, so all paths are assumed + to be absolute. + """ + + protocol = "smb" + + # pylint: disable=too-many-arguments + def __init__( + self, + host, + port=None, + username=None, + password=None, + timeout=60, + encrypt=None, + share_access=None, + register_session_retries=4, + register_session_retry_wait=1, + register_session_retry_factor=10, + auto_mkdir=False, + **kwargs, + ): + """ + You can use _get_kwargs_from_urls to get some kwargs from + a reasonable SMB url. + + Authentication will be anonymous or integrated if username/password are not + given. + + Parameters + ---------- + host: str + The remote server name/ip to connect to + port: int or None + Port to connect with. Usually 445, sometimes 139. + username: str or None + Username to connect with. Required if Kerberos auth is not being used. + password: str or None + User's password on the server, if using username + timeout: int + Connection timeout in seconds + encrypt: bool + Whether to force encryption or not, once this has been set to True + the session cannot be changed back to False. + share_access: str or None + Specifies the default access applied to file open operations + performed with this file system object. + This affects whether other processes can concurrently open a handle + to the same file. + + - None (the default): exclusively locks the file until closed. + - 'r': Allow other handles to be opened with read access. + - 'w': Allow other handles to be opened with write access. + - 'd': Allow other handles to be opened with delete access. + register_session_retries: int + Number of retries to register a session with the server. Retries are not performed + for authentication errors, as they are considered as invalid credentials and not network + issues. If set to negative value, no register attempts will be performed. + register_session_retry_wait: int + Time in seconds to wait between each retry. Number must be non-negative. + register_session_retry_factor: int + Base factor for the wait time between each retry. The wait time + is calculated using exponential function. For factor=1 all wait times + will be equal to `register_session_retry_wait`. For any number of retries, + the last wait time will be equal to `register_session_retry_wait` and for retries>1 + the first wait time will be equal to `register_session_retry_wait / factor`. + Number must be equal to or greater than 1. Optimal factor is 10. + auto_mkdir: bool + Whether, when opening a file, the directory containing it should + be created (if it doesn't already exist). This is assumed by pyarrow + and zarr-python code. + """ + super().__init__(**kwargs) + self.host = host + self.port = port + self.username = username + self.password = password + self.timeout = timeout + self.encrypt = encrypt + self.temppath = kwargs.pop("temppath", "") + self.share_access = share_access + self.register_session_retries = register_session_retries + if register_session_retry_wait < 0: + raise ValueError( + "register_session_retry_wait must be a non-negative integer" + ) + self.register_session_retry_wait = register_session_retry_wait + if register_session_retry_factor < 1: + raise ValueError( + "register_session_retry_factor must be a positive " + "integer equal to or greater than 1" + ) + self.register_session_retry_factor = register_session_retry_factor + self.auto_mkdir = auto_mkdir + self._connect() + + @property + def _port(self): + return 445 if self.port is None else self.port + + def _connect(self): + import time + + if self.register_session_retries <= -1: + return + + retried_errors = [] + + wait_time = self.register_session_retry_wait + n_waits = ( + self.register_session_retries - 1 + ) # -1 = No wait time after the last retry + factor = self.register_session_retry_factor + + # Generate wait times for each retry attempt. + # Wait times are calculated using exponential function. For factor=1 all wait times + # will be equal to `wait`. For any number of retries the last wait time will be + # equal to `wait` and for retries>2 the first wait time will be equal to `wait / factor`. + wait_times = iter( + factor ** (n / n_waits - 1) * wait_time for n in range(0, n_waits + 1) + ) + + for attempt in range(self.register_session_retries + 1): + try: + smbclient.register_session( + self.host, + username=self.username, + password=self.password, + port=self._port, + encrypt=self.encrypt, + connection_timeout=self.timeout, + ) + return + except ( + smbprotocol.exceptions.SMBAuthenticationError, + smbprotocol.exceptions.LogonFailure, + ): + # These exceptions should not be repeated, as they clearly indicate + # that the credentials are invalid and not a network issue. + raise + except ValueError as exc: + if re.findall(r"\[Errno -\d+]", str(exc)): + # This exception is raised by the smbprotocol.transport:Tcp.connect + # and originates from socket.gaierror (OSError). These exceptions might + # be raised due to network instability. We will retry to connect. + retried_errors.append(exc) + else: + # All another ValueError exceptions should be raised, as they are not + # related to network issues. + raise + except Exception as exc: + # Save the exception and retry to connect. This except might be dropped + # in the future, once all exceptions suited for retry are identified. + retried_errors.append(exc) + + if attempt < self.register_session_retries: + time.sleep(next(wait_times)) + + # Raise last exception to inform user about the connection issues. + # Note: Should we use ExceptionGroup to raise all exceptions? + raise retried_errors[-1] + + @classmethod + def _strip_protocol(cls, path): + return infer_storage_options(path)["path"] + + @staticmethod + def _get_kwargs_from_urls(path): + # smb://workgroup;user:password@host:port/share/folder/file.csv + out = infer_storage_options(path) + out.pop("path", None) + out.pop("protocol", None) + return out + + def mkdir(self, path, create_parents=True, **kwargs): + wpath = _as_unc_path(self.host, path) + if create_parents: + smbclient.makedirs(wpath, exist_ok=False, port=self._port, **kwargs) + else: + smbclient.mkdir(wpath, port=self._port, **kwargs) + + def makedirs(self, path, exist_ok=False): + if _share_has_path(path): + wpath = _as_unc_path(self.host, path) + smbclient.makedirs(wpath, exist_ok=exist_ok, port=self._port) + + def rmdir(self, path): + if _share_has_path(path): + wpath = _as_unc_path(self.host, path) + smbclient.rmdir(wpath, port=self._port) + + def info(self, path, **kwargs): + wpath = _as_unc_path(self.host, path) + stats = smbclient.stat(wpath, port=self._port, **kwargs) + if S_ISDIR(stats.st_mode): + stype = "directory" + elif S_ISLNK(stats.st_mode): + stype = "link" + else: + stype = "file" + res = { + "name": path + "/" if stype == "directory" else path, + "size": stats.st_size, + "type": stype, + "uid": stats.st_uid, + "gid": stats.st_gid, + "time": stats.st_atime, + "mtime": stats.st_mtime, + } + return res + + def created(self, path): + """Return the created timestamp of a file as a datetime.datetime""" + wpath = _as_unc_path(self.host, path) + stats = smbclient.stat(wpath, port=self._port) + return datetime.datetime.fromtimestamp(stats.st_ctime, tz=datetime.timezone.utc) + + def modified(self, path): + """Return the modified timestamp of a file as a datetime.datetime""" + wpath = _as_unc_path(self.host, path) + stats = smbclient.stat(wpath, port=self._port) + return datetime.datetime.fromtimestamp(stats.st_mtime, tz=datetime.timezone.utc) + + def ls(self, path, detail=True, **kwargs): + unc = _as_unc_path(self.host, path) + listed = smbclient.listdir(unc, port=self._port, **kwargs) + dirs = ["/".join([path.rstrip("/"), p]) for p in listed] + if detail: + dirs = [self.info(d) for d in dirs] + return dirs + + # pylint: disable=too-many-arguments + def _open( + self, + path, + mode="rb", + block_size=-1, + autocommit=True, + cache_options=None, + **kwargs, + ): + """ + block_size: int or None + If 0, no buffering, 1, line buffering, >1, buffer that many bytes + + Notes + ----- + By specifying 'share_access' in 'kwargs' it is possible to override the + default shared access setting applied in the constructor of this object. + """ + if self.auto_mkdir and "w" in mode: + self.makedirs(self._parent(path), exist_ok=True) + bls = block_size if block_size is not None and block_size >= 0 else -1 + wpath = _as_unc_path(self.host, path) + share_access = kwargs.pop("share_access", self.share_access) + if "w" in mode and autocommit is False: + temp = _as_temp_path(self.host, path, self.temppath) + return SMBFileOpener( + wpath, temp, mode, port=self._port, block_size=bls, **kwargs + ) + return smbclient.open_file( + wpath, + mode, + buffering=bls, + share_access=share_access, + port=self._port, + **kwargs, + ) + + def copy(self, path1, path2, **kwargs): + """Copy within two locations in the same filesystem""" + wpath1 = _as_unc_path(self.host, path1) + wpath2 = _as_unc_path(self.host, path2) + if self.auto_mkdir: + self.makedirs(self._parent(path2), exist_ok=True) + smbclient.copyfile(wpath1, wpath2, port=self._port, **kwargs) + + def _rm(self, path): + if _share_has_path(path): + wpath = _as_unc_path(self.host, path) + stats = smbclient.stat(wpath, port=self._port) + if S_ISDIR(stats.st_mode): + smbclient.rmdir(wpath, port=self._port) + else: + smbclient.remove(wpath, port=self._port) + + def mv(self, path1, path2, recursive=None, maxdepth=None, **kwargs): + wpath1 = _as_unc_path(self.host, path1) + wpath2 = _as_unc_path(self.host, path2) + smbclient.rename(wpath1, wpath2, port=self._port, **kwargs) + + +def _as_unc_path(host, path): + rpath = path.replace("/", "\\") + unc = f"\\\\{host}{rpath}" + return unc + + +def _as_temp_path(host, path, temppath): + share = path.split("/")[1] + temp_file = f"/{share}{temppath}/{uuid.uuid4()}" + unc = _as_unc_path(host, temp_file) + return unc + + +def _share_has_path(path): + parts = path.count("/") + if path.endswith("/"): + return parts > 2 + return parts > 1 + + +class SMBFileOpener: + """writes to remote temporary file, move on commit""" + + def __init__(self, path, temp, mode, port=445, block_size=-1, **kwargs): + self.path = path + self.temp = temp + self.mode = mode + self.block_size = block_size + self.kwargs = kwargs + self.smbfile = None + self._incontext = False + self.port = port + self._open() + + def _open(self): + if self.smbfile is None or self.smbfile.closed: + self.smbfile = smbclient.open_file( + self.temp, + self.mode, + port=self.port, + buffering=self.block_size, + **self.kwargs, + ) + + def commit(self): + """Move temp file to definitive on success.""" + # TODO: use transaction support in SMB protocol + smbclient.replace(self.temp, self.path, port=self.port) + + def discard(self): + """Remove the temp file on failure.""" + smbclient.remove(self.temp, port=self.port) + + def __fspath__(self): + return self.path + + def __iter__(self): + return self.smbfile.__iter__() + + def __getattr__(self, item): + return getattr(self.smbfile, item) + + def __enter__(self): + self._incontext = True + return self.smbfile.__enter__() + + def __exit__(self, exc_type, exc_value, traceback): + self._incontext = False + self.smbfile.__exit__(exc_type, exc_value, traceback) diff --git a/fsspec/implementations/tar.py b/fsspec/implementations/tar.py new file mode 100644 index 0000000000000000000000000000000000000000..412e5ba4d2cdea7db090dc96412e697909a38d78 --- /dev/null +++ b/fsspec/implementations/tar.py @@ -0,0 +1,124 @@ +import logging +import tarfile + +import fsspec +from fsspec.archive import AbstractArchiveFileSystem +from fsspec.compression import compr +from fsspec.utils import infer_compression + +typemap = {b"0": "file", b"5": "directory"} + +logger = logging.getLogger("tar") + + +class TarFileSystem(AbstractArchiveFileSystem): + """Compressed Tar archives as a file-system (read-only) + + Supports the following formats: + tar.gz, tar.bz2, tar.xz + """ + + root_marker = "" + protocol = "tar" + cachable = False + + def __init__( + self, + fo="", + index_store=None, + target_options=None, + target_protocol=None, + compression=None, + **kwargs, + ): + super().__init__(**kwargs) + target_options = target_options or {} + + if isinstance(fo, str): + self.of = fsspec.open(fo, protocol=target_protocol, **target_options) + fo = self.of.open() # keep the reference + + # Try to infer compression. + if compression is None: + name = None + + # Try different ways to get hold of the filename. `fo` might either + # be a `fsspec.LocalFileOpener`, an `io.BufferedReader` or an + # `fsspec.AbstractFileSystem` instance. + try: + # Amended io.BufferedReader or similar. + # This uses a "protocol extension" where original filenames are + # propagated to archive-like filesystems in order to let them + # infer the right compression appropriately. + if hasattr(fo, "original"): + name = fo.original + + # fsspec.LocalFileOpener + elif hasattr(fo, "path"): + name = fo.path + + # io.BufferedReader + elif hasattr(fo, "name"): + name = fo.name + + # fsspec.AbstractFileSystem + elif hasattr(fo, "info"): + name = fo.info()["name"] + + except Exception as ex: + logger.warning( + f"Unable to determine file name, not inferring compression: {ex}" + ) + + if name is not None: + compression = infer_compression(name) + logger.info(f"Inferred compression {compression} from file name {name}") + + if compression is not None: + # TODO: tarfile already implements compression with modes like "'r:gz'", + # but then would seek to offset in the file work? + fo = compr[compression](fo) + + self._fo_ref = fo + self.fo = fo # the whole instance is a context + self.tar = tarfile.TarFile(fileobj=self.fo) + self.dir_cache = None + + self.index_store = index_store + self.index = None + self._index() + + def _index(self): + # TODO: load and set saved index, if exists + out = {} + for ti in self.tar: + info = ti.get_info() + info["type"] = typemap.get(info["type"], "file") + name = ti.get_info()["name"].rstrip("/") + out[name] = (info, ti.offset_data) + + self.index = out + # TODO: save index to self.index_store here, if set + + def _get_dirs(self): + if self.dir_cache is not None: + return + + # This enables ls to get directories as children as well as files + self.dir_cache = { + dirname: {"name": dirname, "size": 0, "type": "directory"} + for dirname in self._all_dirnames(self.tar.getnames()) + } + for member in self.tar.getmembers(): + info = member.get_info() + info["name"] = info["name"].rstrip("/") + info["type"] = typemap.get(info["type"], "file") + self.dir_cache[info["name"]] = info + + def _open(self, path, mode="rb", **kwargs): + if mode != "rb": + raise ValueError("Read-only filesystem implementation") + details, offset = self.index[path] + if details["type"] != "file": + raise ValueError("Can only handle regular files") + return self.tar.extractfile(path) diff --git a/fsspec/implementations/webhdfs.py b/fsspec/implementations/webhdfs.py new file mode 100644 index 0000000000000000000000000000000000000000..c6e0d8446f875ec9578cccf055aeb47cd1c2e996 --- /dev/null +++ b/fsspec/implementations/webhdfs.py @@ -0,0 +1,485 @@ +# https://hadoop.apache.org/docs/r1.0.4/webhdfs.html + +import logging +import os +import secrets +import shutil +import tempfile +import uuid +from contextlib import suppress +from urllib.parse import quote + +import requests + +from ..spec import AbstractBufferedFile, AbstractFileSystem +from ..utils import infer_storage_options, tokenize + +logger = logging.getLogger("webhdfs") + + +class WebHDFS(AbstractFileSystem): + """ + Interface to HDFS over HTTP using the WebHDFS API. Supports also HttpFS gateways. + + Four auth mechanisms are supported: + + insecure: no auth is done, and the user is assumed to be whoever they + say they are (parameter ``user``), or a predefined value such as + "dr.who" if not given + spnego: when kerberos authentication is enabled, auth is negotiated by + requests_kerberos https://github.com/requests/requests-kerberos . + This establishes a session based on existing kinit login and/or + specified principal/password; parameters are passed with ``kerb_kwargs`` + token: uses an existing Hadoop delegation token from another secured + service. Indeed, this client can also generate such tokens when + not insecure. Note that tokens expire, but can be renewed (by a + previously specified user) and may allow for proxying. + basic-auth: used when both parameter ``user`` and parameter ``password`` + are provided. + + """ + + tempdir = str(tempfile.gettempdir()) + protocol = "webhdfs", "webHDFS" + + def __init__( + self, + host, + port=50070, + kerberos=False, + token=None, + user=None, + password=None, + proxy_to=None, + kerb_kwargs=None, + data_proxy=None, + use_https=False, + session_cert=None, + session_verify=True, + **kwargs, + ): + """ + Parameters + ---------- + host: str + Name-node address + port: int + Port for webHDFS + kerberos: bool + Whether to authenticate with kerberos for this connection + token: str or None + If given, use this token on every call to authenticate. A user + and user-proxy may be encoded in the token and should not be also + given + user: str or None + If given, assert the user name to connect with + password: str or None + If given, assert the password to use for basic auth. If password + is provided, user must be provided also + proxy_to: str or None + If given, the user has the authority to proxy, and this value is + the user in who's name actions are taken + kerb_kwargs: dict + Any extra arguments for HTTPKerberosAuth, see + ``_ + data_proxy: dict, callable or None + If given, map data-node addresses. This can be necessary if the + HDFS cluster is behind a proxy, running on Docker or otherwise has + a mismatch between the host-names given by the name-node and the + address by which to refer to them from the client. If a dict, + maps host names ``host->data_proxy[host]``; if a callable, full + URLs are passed, and function must conform to + ``url->data_proxy(url)``. + use_https: bool + Whether to connect to the Name-node using HTTPS instead of HTTP + session_cert: str or Tuple[str, str] or None + Path to a certificate file, or tuple of (cert, key) files to use + for the requests.Session + session_verify: str, bool or None + Path to a certificate file to use for verifying the requests.Session. + kwargs + """ + if self._cached: + return + super().__init__(**kwargs) + self.url = f"{'https' if use_https else 'http'}://{host}:{port}/webhdfs/v1" + self.kerb = kerberos + self.kerb_kwargs = kerb_kwargs or {} + self.pars = {} + self.proxy = data_proxy or {} + if token is not None: + if user is not None or proxy_to is not None: + raise ValueError( + "If passing a delegation token, must not set " + "user or proxy_to, as these are encoded in the" + " token" + ) + self.pars["delegation"] = token + self.user = user + self.password = password + + if password is not None: + if user is None: + raise ValueError( + "If passing a password, the user must also be" + "set in order to set up the basic-auth" + ) + else: + if user is not None: + self.pars["user.name"] = user + + if proxy_to is not None: + self.pars["doas"] = proxy_to + if kerberos and user is not None: + raise ValueError( + "If using Kerberos auth, do not specify the " + "user, this is handled by kinit." + ) + + self.session_cert = session_cert + self.session_verify = session_verify + + self._connect() + + self._fsid = f"webhdfs_{tokenize(host, port)}" + + @property + def fsid(self): + return self._fsid + + def _connect(self): + self.session = requests.Session() + + if self.session_cert: + self.session.cert = self.session_cert + + self.session.verify = self.session_verify + + if self.kerb: + from requests_kerberos import HTTPKerberosAuth + + self.session.auth = HTTPKerberosAuth(**self.kerb_kwargs) + + if self.user is not None and self.password is not None: + from requests.auth import HTTPBasicAuth + + self.session.auth = HTTPBasicAuth(self.user, self.password) + + def _call(self, op, method="get", path=None, data=None, redirect=True, **kwargs): + path = self._strip_protocol(path) if path is not None else "" + url = self._apply_proxy(self.url + quote(path, safe="/=")) + args = kwargs.copy() + args.update(self.pars) + args["op"] = op.upper() + logger.debug("sending %s with %s", url, method) + out = self.session.request( + method=method.upper(), + url=url, + params=args, + data=data, + allow_redirects=redirect, + ) + if out.status_code in [400, 401, 403, 404, 500]: + try: + err = out.json() + msg = err["RemoteException"]["message"] + exp = err["RemoteException"]["exception"] + except (ValueError, KeyError): + pass + else: + if exp in ["IllegalArgumentException", "UnsupportedOperationException"]: + raise ValueError(msg) + elif exp in ["SecurityException", "AccessControlException"]: + raise PermissionError(msg) + elif exp in ["FileNotFoundException"]: + raise FileNotFoundError(msg) + else: + raise RuntimeError(msg) + out.raise_for_status() + return out + + def _open( + self, + path, + mode="rb", + block_size=None, + autocommit=True, + replication=None, + permissions=None, + **kwargs, + ): + """ + + Parameters + ---------- + path: str + File location + mode: str + 'rb', 'wb', etc. + block_size: int + Client buffer size for read-ahead or write buffer + autocommit: bool + If False, writes to temporary file that only gets put in final + location upon commit + replication: int + Number of copies of file on the cluster, write mode only + permissions: str or int + posix permissions, write mode only + kwargs + + Returns + ------- + WebHDFile instance + """ + block_size = block_size or self.blocksize + return WebHDFile( + self, + path, + mode=mode, + block_size=block_size, + tempdir=self.tempdir, + autocommit=autocommit, + replication=replication, + permissions=permissions, + ) + + @staticmethod + def _process_info(info): + info["type"] = info["type"].lower() + info["size"] = info["length"] + return info + + @classmethod + def _strip_protocol(cls, path): + return infer_storage_options(path)["path"] + + @staticmethod + def _get_kwargs_from_urls(urlpath): + out = infer_storage_options(urlpath) + out.pop("path", None) + out.pop("protocol", None) + if "username" in out: + out["user"] = out.pop("username") + return out + + def info(self, path): + out = self._call("GETFILESTATUS", path=path) + info = out.json()["FileStatus"] + info["name"] = path + return self._process_info(info) + + def ls(self, path, detail=False): + out = self._call("LISTSTATUS", path=path) + infos = out.json()["FileStatuses"]["FileStatus"] + for info in infos: + self._process_info(info) + info["name"] = path.rstrip("/") + "/" + info["pathSuffix"] + if detail: + return sorted(infos, key=lambda i: i["name"]) + else: + return sorted(info["name"] for info in infos) + + def content_summary(self, path): + """Total numbers of files, directories and bytes under path""" + out = self._call("GETCONTENTSUMMARY", path=path) + return out.json()["ContentSummary"] + + def ukey(self, path): + """Checksum info of file, giving method and result""" + out = self._call("GETFILECHECKSUM", path=path, redirect=False) + if "Location" in out.headers: + location = self._apply_proxy(out.headers["Location"]) + out2 = self.session.get(location) + out2.raise_for_status() + return out2.json()["FileChecksum"] + else: + out.raise_for_status() + return out.json()["FileChecksum"] + + def home_directory(self): + """Get user's home directory""" + out = self._call("GETHOMEDIRECTORY") + return out.json()["Path"] + + def get_delegation_token(self, renewer=None): + """Retrieve token which can give the same authority to other uses + + Parameters + ---------- + renewer: str or None + User who may use this token; if None, will be current user + """ + if renewer: + out = self._call("GETDELEGATIONTOKEN", renewer=renewer) + else: + out = self._call("GETDELEGATIONTOKEN") + t = out.json()["Token"] + if t is None: + raise ValueError("No token available for this user/security context") + return t["urlString"] + + def renew_delegation_token(self, token): + """Make token live longer. Returns new expiry time""" + out = self._call("RENEWDELEGATIONTOKEN", method="put", token=token) + return out.json()["long"] + + def cancel_delegation_token(self, token): + """Stop the token from being useful""" + self._call("CANCELDELEGATIONTOKEN", method="put", token=token) + + def chmod(self, path, mod): + """Set the permission at path + + Parameters + ---------- + path: str + location to set (file or directory) + mod: str or int + posix epresentation or permission, give as oct string, e.g, '777' + or 0o777 + """ + self._call("SETPERMISSION", method="put", path=path, permission=mod) + + def chown(self, path, owner=None, group=None): + """Change owning user and/or group""" + kwargs = {} + if owner is not None: + kwargs["owner"] = owner + if group is not None: + kwargs["group"] = group + self._call("SETOWNER", method="put", path=path, **kwargs) + + def set_replication(self, path, replication): + """ + Set file replication factor + + Parameters + ---------- + path: str + File location (not for directories) + replication: int + Number of copies of file on the cluster. Should be smaller than + number of data nodes; normally 3 on most systems. + """ + self._call("SETREPLICATION", path=path, method="put", replication=replication) + + def mkdir(self, path, **kwargs): + self._call("MKDIRS", method="put", path=path) + + def makedirs(self, path, exist_ok=False): + if exist_ok is False and self.exists(path): + raise FileExistsError(path) + self.mkdir(path) + + def mv(self, path1, path2, **kwargs): + self._call("RENAME", method="put", path=path1, destination=path2) + + def rm(self, path, recursive=False, **kwargs): + self._call( + "DELETE", + method="delete", + path=path, + recursive="true" if recursive else "false", + ) + + def rm_file(self, path, **kwargs): + self.rm(path) + + def cp_file(self, lpath, rpath, **kwargs): + with self.open(lpath) as lstream: + tmp_fname = "/".join([self._parent(rpath), f".tmp.{secrets.token_hex(16)}"]) + # Perform an atomic copy (stream to a temporary file and + # move it to the actual destination). + try: + with self.open(tmp_fname, "wb") as rstream: + shutil.copyfileobj(lstream, rstream) + self.mv(tmp_fname, rpath) + except BaseException: + with suppress(FileNotFoundError): + self.rm(tmp_fname) + raise + + def _apply_proxy(self, location): + if self.proxy and callable(self.proxy): + location = self.proxy(location) + elif self.proxy: + # as a dict + for k, v in self.proxy.items(): + location = location.replace(k, v, 1) + return location + + +class WebHDFile(AbstractBufferedFile): + """A file living in HDFS over webHDFS""" + + def __init__(self, fs, path, **kwargs): + super().__init__(fs, path, **kwargs) + kwargs = kwargs.copy() + if kwargs.get("permissions", None) is None: + kwargs.pop("permissions", None) + if kwargs.get("replication", None) is None: + kwargs.pop("replication", None) + self.permissions = kwargs.pop("permissions", 511) + tempdir = kwargs.pop("tempdir") + if kwargs.pop("autocommit", False) is False: + self.target = self.path + self.path = os.path.join(tempdir, str(uuid.uuid4())) + + def _upload_chunk(self, final=False): + """Write one part of a multi-block file upload + + Parameters + ========== + final: bool + This is the last block, so should complete file, if + self.autocommit is True. + """ + out = self.fs.session.post( + self.location, + data=self.buffer.getvalue(), + headers={"content-type": "application/octet-stream"}, + ) + out.raise_for_status() + return True + + def _initiate_upload(self): + """Create remote file/upload""" + kwargs = self.kwargs.copy() + if "a" in self.mode: + op, method = "APPEND", "POST" + else: + op, method = "CREATE", "PUT" + kwargs["overwrite"] = "true" + out = self.fs._call(op, method, self.path, redirect=False, **kwargs) + location = self.fs._apply_proxy(out.headers["Location"]) + if "w" in self.mode: + # create empty file to append to + out2 = self.fs.session.put( + location, headers={"content-type": "application/octet-stream"} + ) + out2.raise_for_status() + # after creating empty file, change location to append to + out2 = self.fs._call("APPEND", "POST", self.path, redirect=False, **kwargs) + self.location = self.fs._apply_proxy(out2.headers["Location"]) + + def _fetch_range(self, start, end): + start = max(start, 0) + end = min(self.size, end) + if start >= end or start >= self.size: + return b"" + out = self.fs._call( + "OPEN", path=self.path, offset=start, length=end - start, redirect=False + ) + out.raise_for_status() + if "Location" in out.headers: + location = out.headers["Location"] + out2 = self.fs.session.get(self.fs._apply_proxy(location)) + return out2.content + else: + return out.content + + def commit(self): + self.fs.mv(self.path, self.target) + + def discard(self): + self.fs.rm(self.path) diff --git a/fsspec/implementations/zip.py b/fsspec/implementations/zip.py new file mode 100644 index 0000000000000000000000000000000000000000..6db3ae27806106a19a366886ab4b183f85c1cb1a --- /dev/null +++ b/fsspec/implementations/zip.py @@ -0,0 +1,177 @@ +import os +import zipfile + +import fsspec +from fsspec.archive import AbstractArchiveFileSystem + + +class ZipFileSystem(AbstractArchiveFileSystem): + """Read/Write contents of ZIP archive as a file-system + + Keeps file object open while instance lives. + + This class is pickleable, but not necessarily thread-safe + """ + + root_marker = "" + protocol = "zip" + cachable = False + + def __init__( + self, + fo="", + mode="r", + target_protocol=None, + target_options=None, + compression=zipfile.ZIP_STORED, + allowZip64=True, + compresslevel=None, + **kwargs, + ): + """ + Parameters + ---------- + fo: str or file-like + Contains ZIP, and must exist. If a str, will fetch file using + :meth:`~fsspec.open_files`, which must return one file exactly. + mode: str + Accept: "r", "w", "a" + target_protocol: str (optional) + If ``fo`` is a string, this value can be used to override the + FS protocol inferred from a URL + target_options: dict (optional) + Kwargs passed when instantiating the target FS, if ``fo`` is + a string. + compression, allowZip64, compresslevel: passed to ZipFile + Only relevant when creating a ZIP + """ + super().__init__(self, **kwargs) + if mode not in set("rwa"): + raise ValueError(f"mode '{mode}' no understood") + self.mode = mode + if isinstance(fo, (str, os.PathLike)): + if mode == "a": + m = "r+b" + else: + m = mode + "b" + fo = fsspec.open( + fo, mode=m, protocol=target_protocol, **(target_options or {}) + ) + self.force_zip_64 = allowZip64 + self.of = fo + self.fo = fo.__enter__() # the whole instance is a context + self.zip = zipfile.ZipFile( + self.fo, + mode=mode, + compression=compression, + allowZip64=allowZip64, + compresslevel=compresslevel, + ) + self.dir_cache = None + + @classmethod + def _strip_protocol(cls, path): + # zip file paths are always relative to the archive root + return super()._strip_protocol(path).lstrip("/") + + def __del__(self): + if hasattr(self, "zip"): + self.close() + del self.zip + + def close(self): + """Commits any write changes to the file. Done on ``del`` too.""" + self.zip.close() + + def _get_dirs(self): + if self.dir_cache is None or self.mode in set("wa"): + # when writing, dir_cache is always in the ZipFile's attributes, + # not read from the file. + files = self.zip.infolist() + self.dir_cache = { + dirname.rstrip("/"): { + "name": dirname.rstrip("/"), + "size": 0, + "type": "directory", + } + for dirname in self._all_dirnames(self.zip.namelist()) + } + for z in files: + f = {s: getattr(z, s, None) for s in zipfile.ZipInfo.__slots__} + f.update( + { + "name": z.filename.rstrip("/"), + "size": z.file_size, + "type": ("directory" if z.is_dir() else "file"), + } + ) + self.dir_cache[f["name"]] = f + + def pipe_file(self, path, value, **kwargs): + # override upstream, because we know the exact file size in this case + self.zip.writestr(path, value, **kwargs) + + def _open( + self, + path, + mode="rb", + block_size=None, + autocommit=True, + cache_options=None, + **kwargs, + ): + path = self._strip_protocol(path) + if "r" in mode and self.mode in set("wa"): + if self.exists(path): + raise OSError("ZipFS can only be open for reading or writing, not both") + raise FileNotFoundError(path) + if "r" in self.mode and "w" in mode: + raise OSError("ZipFS can only be open for reading or writing, not both") + out = self.zip.open(path, mode.strip("b"), force_zip64=self.force_zip_64) + if "r" in mode: + info = self.info(path) + out.size = info["size"] + out.name = info["name"] + return out + + def find(self, path, maxdepth=None, withdirs=False, detail=False, **kwargs): + if maxdepth is not None and maxdepth < 1: + raise ValueError("maxdepth must be at least 1") + + # Remove the leading slash, as the zip file paths are always + # given without a leading slash + path = path.lstrip("/") + path_parts = list(filter(lambda s: bool(s), path.split("/"))) + + def _matching_starts(file_path): + file_parts = filter(lambda s: bool(s), file_path.split("/")) + return all(a == b for a, b in zip(path_parts, file_parts)) + + self._get_dirs() + + result = {} + # To match posix find, if an exact file name is given, we should + # return only that file + if path in self.dir_cache and self.dir_cache[path]["type"] == "file": + result[path] = self.dir_cache[path] + return result if detail else [path] + + for file_path, file_info in self.dir_cache.items(): + if not (path == "" or _matching_starts(file_path)): + continue + + if file_info["type"] == "directory": + if withdirs: + if file_path not in result: + result[file_path.strip("/")] = file_info + continue + + if file_path not in result: + result[file_path] = file_info if detail else None + + if maxdepth: + path_depth = path.count("/") + result = { + k: v for k, v in result.items() if k.count("/") - path_depth < maxdepth + } + return result if detail else sorted(result) diff --git a/fsspec/tests/abstract/__init__.py b/fsspec/tests/abstract/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8ed2ad802ecaf021106c25c03112f29e75c7b2f8 --- /dev/null +++ b/fsspec/tests/abstract/__init__.py @@ -0,0 +1,289 @@ +import os +from hashlib import md5 + +import pytest + +from fsspec.implementations.local import LocalFileSystem +from fsspec.tests.abstract.copy import AbstractCopyTests # noqa: F401 +from fsspec.tests.abstract.get import AbstractGetTests # noqa: F401 +from fsspec.tests.abstract.open import AbstractOpenTests # noqa: F401 +from fsspec.tests.abstract.pipe import AbstractPipeTests # noqa: F401 +from fsspec.tests.abstract.put import AbstractPutTests # noqa: F401 + + +class BaseAbstractFixtures: + """ + Abstract base class containing fixtures that are used by but never need to + be overridden in derived filesystem-specific classes to run the abstract + tests on such filesystems. + """ + + @pytest.fixture + def fs_bulk_operations_scenario_0(self, fs, fs_join, fs_path): + """ + Scenario on remote filesystem that is used for many cp/get/put tests. + + Cleans up at the end of each test it which it is used. + """ + source = self._bulk_operations_scenario_0(fs, fs_join, fs_path) + yield source + fs.rm(source, recursive=True) + + @pytest.fixture + def fs_glob_edge_cases_files(self, fs, fs_join, fs_path): + """ + Scenario on remote filesystem that is used for glob edge cases cp/get/put tests. + + Cleans up at the end of each test it which it is used. + """ + source = self._glob_edge_cases_files(fs, fs_join, fs_path) + yield source + fs.rm(source, recursive=True) + + @pytest.fixture + def fs_dir_and_file_with_same_name_prefix(self, fs, fs_join, fs_path): + """ + Scenario on remote filesystem that is used to check cp/get/put on directory + and file with the same name prefixes. + + Cleans up at the end of each test it which it is used. + """ + source = self._dir_and_file_with_same_name_prefix(fs, fs_join, fs_path) + yield source + fs.rm(source, recursive=True) + + @pytest.fixture + def fs_10_files_with_hashed_names(self, fs, fs_join, fs_path): + """ + Scenario on remote filesystem that is used to check cp/get/put files order + when source and destination are lists. + + Cleans up at the end of each test it which it is used. + """ + source = self._10_files_with_hashed_names(fs, fs_join, fs_path) + yield source + fs.rm(source, recursive=True) + + @pytest.fixture + def fs_target(self, fs, fs_join, fs_path): + """ + Return name of remote directory that does not yet exist to copy into. + + Cleans up at the end of each test it which it is used. + """ + target = fs_join(fs_path, "target") + yield target + if fs.exists(target): + fs.rm(target, recursive=True) + + @pytest.fixture + def local_bulk_operations_scenario_0(self, local_fs, local_join, local_path): + """ + Scenario on local filesystem that is used for many cp/get/put tests. + + Cleans up at the end of each test it which it is used. + """ + source = self._bulk_operations_scenario_0(local_fs, local_join, local_path) + yield source + local_fs.rm(source, recursive=True) + + @pytest.fixture + def local_glob_edge_cases_files(self, local_fs, local_join, local_path): + """ + Scenario on local filesystem that is used for glob edge cases cp/get/put tests. + + Cleans up at the end of each test it which it is used. + """ + source = self._glob_edge_cases_files(local_fs, local_join, local_path) + yield source + local_fs.rm(source, recursive=True) + + @pytest.fixture + def local_dir_and_file_with_same_name_prefix( + self, local_fs, local_join, local_path + ): + """ + Scenario on local filesystem that is used to check cp/get/put on directory + and file with the same name prefixes. + + Cleans up at the end of each test it which it is used. + """ + source = self._dir_and_file_with_same_name_prefix( + local_fs, local_join, local_path + ) + yield source + local_fs.rm(source, recursive=True) + + @pytest.fixture + def local_10_files_with_hashed_names(self, local_fs, local_join, local_path): + """ + Scenario on local filesystem that is used to check cp/get/put files order + when source and destination are lists. + + Cleans up at the end of each test it which it is used. + """ + source = self._10_files_with_hashed_names(local_fs, local_join, local_path) + yield source + local_fs.rm(source, recursive=True) + + @pytest.fixture + def local_target(self, local_fs, local_join, local_path): + """ + Return name of local directory that does not yet exist to copy into. + + Cleans up at the end of each test it which it is used. + """ + target = local_join(local_path, "target") + yield target + if local_fs.exists(target): + local_fs.rm(target, recursive=True) + + def _glob_edge_cases_files(self, some_fs, some_join, some_path): + """ + Scenario that is used for glob edge cases cp/get/put tests. + Creates the following directory and file structure: + + 📁 source + ├── 📄 file1 + ├── 📄 file2 + ├── 📁 subdir0 + │ ├── 📄 subfile1 + │ ├── 📄 subfile2 + │ └── 📁 nesteddir + │ └── 📄 nestedfile + └── 📁 subdir1 + ├── 📄 subfile1 + ├── 📄 subfile2 + └── 📁 nesteddir + └── 📄 nestedfile + """ + source = some_join(some_path, "source") + some_fs.touch(some_join(source, "file1")) + some_fs.touch(some_join(source, "file2")) + + for subdir_idx in range(2): + subdir = some_join(source, f"subdir{subdir_idx}") + nesteddir = some_join(subdir, "nesteddir") + some_fs.makedirs(nesteddir) + some_fs.touch(some_join(subdir, "subfile1")) + some_fs.touch(some_join(subdir, "subfile2")) + some_fs.touch(some_join(nesteddir, "nestedfile")) + + return source + + def _bulk_operations_scenario_0(self, some_fs, some_join, some_path): + """ + Scenario that is used for many cp/get/put tests. Creates the following + directory and file structure: + + 📁 source + ├── 📄 file1 + ├── 📄 file2 + └── 📁 subdir + ├── 📄 subfile1 + ├── 📄 subfile2 + └── 📁 nesteddir + └── 📄 nestedfile + """ + source = some_join(some_path, "source") + subdir = some_join(source, "subdir") + nesteddir = some_join(subdir, "nesteddir") + some_fs.makedirs(nesteddir) + some_fs.touch(some_join(source, "file1")) + some_fs.touch(some_join(source, "file2")) + some_fs.touch(some_join(subdir, "subfile1")) + some_fs.touch(some_join(subdir, "subfile2")) + some_fs.touch(some_join(nesteddir, "nestedfile")) + return source + + def _dir_and_file_with_same_name_prefix(self, some_fs, some_join, some_path): + """ + Scenario that is used to check cp/get/put on directory and file with + the same name prefixes. Creates the following directory and file structure: + + 📁 source + ├── 📄 subdir.txt + └── 📁 subdir + └── 📄 subfile.txt + """ + source = some_join(some_path, "source") + subdir = some_join(source, "subdir") + file = some_join(source, "subdir.txt") + subfile = some_join(subdir, "subfile.txt") + some_fs.makedirs(subdir) + some_fs.touch(file) + some_fs.touch(subfile) + return source + + def _10_files_with_hashed_names(self, some_fs, some_join, some_path): + """ + Scenario that is used to check cp/get/put files order when source and + destination are lists. Creates the following directory and file structure: + + 📁 source + └── 📄 {hashed([0-9])}.txt + """ + source = some_join(some_path, "source") + for i in range(10): + hashed_i = md5(str(i).encode("utf-8")).hexdigest() + path = some_join(source, f"{hashed_i}.txt") + some_fs.pipe(path=path, value=f"{i}".encode()) + return source + + +class AbstractFixtures(BaseAbstractFixtures): + """ + Abstract base class containing fixtures that may be overridden in derived + filesystem-specific classes to run the abstract tests on such filesystems. + + For any particular filesystem some of these fixtures must be overridden, + such as ``fs`` and ``fs_path``, and others may be overridden if the + default functions here are not appropriate, such as ``fs_join``. + """ + + @pytest.fixture + def fs(self): + raise NotImplementedError("This function must be overridden in derived classes") + + @pytest.fixture + def fs_join(self): + """ + Return a function that joins its arguments together into a path. + + Most fsspec implementations join paths in a platform-dependent way, + but some will override this to always use a forward slash. + """ + return os.path.join + + @pytest.fixture + def fs_path(self): + raise NotImplementedError("This function must be overridden in derived classes") + + @pytest.fixture(scope="class") + def local_fs(self): + # Maybe need an option for auto_mkdir=False? This is only relevant + # for certain implementations. + return LocalFileSystem(auto_mkdir=True) + + @pytest.fixture + def local_join(self): + """ + Return a function that joins its arguments together into a path, on + the local filesystem. + """ + return os.path.join + + @pytest.fixture + def local_path(self, tmpdir): + return tmpdir + + @pytest.fixture + def supports_empty_directories(self): + """ + Return whether this implementation supports empty directories. + """ + return True + + @pytest.fixture + def fs_sanitize_path(self): + return lambda x: x diff --git a/fsspec/tests/abstract/common.py b/fsspec/tests/abstract/common.py new file mode 100644 index 0000000000000000000000000000000000000000..22e7c4140404ab2a8928689721419cf05c2760b9 --- /dev/null +++ b/fsspec/tests/abstract/common.py @@ -0,0 +1,175 @@ +GLOB_EDGE_CASES_TESTS = { + "argnames": ("path", "recursive", "maxdepth", "expected"), + "argvalues": [ + ("fil?1", False, None, ["file1"]), + ("fil?1", True, None, ["file1"]), + ("file[1-2]", False, None, ["file1", "file2"]), + ("file[1-2]", True, None, ["file1", "file2"]), + ("*", False, None, ["file1", "file2"]), + ( + "*", + True, + None, + [ + "file1", + "file2", + "subdir0/subfile1", + "subdir0/subfile2", + "subdir0/nesteddir/nestedfile", + "subdir1/subfile1", + "subdir1/subfile2", + "subdir1/nesteddir/nestedfile", + ], + ), + ("*", True, 1, ["file1", "file2"]), + ( + "*", + True, + 2, + [ + "file1", + "file2", + "subdir0/subfile1", + "subdir0/subfile2", + "subdir1/subfile1", + "subdir1/subfile2", + ], + ), + ("*1", False, None, ["file1"]), + ( + "*1", + True, + None, + [ + "file1", + "subdir1/subfile1", + "subdir1/subfile2", + "subdir1/nesteddir/nestedfile", + ], + ), + ("*1", True, 2, ["file1", "subdir1/subfile1", "subdir1/subfile2"]), + ( + "**", + False, + None, + [ + "file1", + "file2", + "subdir0/subfile1", + "subdir0/subfile2", + "subdir0/nesteddir/nestedfile", + "subdir1/subfile1", + "subdir1/subfile2", + "subdir1/nesteddir/nestedfile", + ], + ), + ( + "**", + True, + None, + [ + "file1", + "file2", + "subdir0/subfile1", + "subdir0/subfile2", + "subdir0/nesteddir/nestedfile", + "subdir1/subfile1", + "subdir1/subfile2", + "subdir1/nesteddir/nestedfile", + ], + ), + ("**", True, 1, ["file1", "file2"]), + ( + "**", + True, + 2, + [ + "file1", + "file2", + "subdir0/subfile1", + "subdir0/subfile2", + "subdir0/nesteddir/nestedfile", + "subdir1/subfile1", + "subdir1/subfile2", + "subdir1/nesteddir/nestedfile", + ], + ), + ( + "**", + False, + 2, + [ + "file1", + "file2", + "subdir0/subfile1", + "subdir0/subfile2", + "subdir1/subfile1", + "subdir1/subfile2", + ], + ), + ("**/*1", False, None, ["file1", "subdir0/subfile1", "subdir1/subfile1"]), + ( + "**/*1", + True, + None, + [ + "file1", + "subdir0/subfile1", + "subdir1/subfile1", + "subdir1/subfile2", + "subdir1/nesteddir/nestedfile", + ], + ), + ("**/*1", True, 1, ["file1"]), + ( + "**/*1", + True, + 2, + ["file1", "subdir0/subfile1", "subdir1/subfile1", "subdir1/subfile2"], + ), + ("**/*1", False, 2, ["file1", "subdir0/subfile1", "subdir1/subfile1"]), + ("**/subdir0", False, None, []), + ("**/subdir0", True, None, ["subfile1", "subfile2", "nesteddir/nestedfile"]), + ("**/subdir0/nested*", False, 2, []), + ("**/subdir0/nested*", True, 2, ["nestedfile"]), + ("subdir[1-2]", False, None, []), + ("subdir[1-2]", True, None, ["subfile1", "subfile2", "nesteddir/nestedfile"]), + ("subdir[1-2]", True, 2, ["subfile1", "subfile2"]), + ("subdir[0-1]", False, None, []), + ( + "subdir[0-1]", + True, + None, + [ + "subdir0/subfile1", + "subdir0/subfile2", + "subdir0/nesteddir/nestedfile", + "subdir1/subfile1", + "subdir1/subfile2", + "subdir1/nesteddir/nestedfile", + ], + ), + ( + "subdir[0-1]/*fil[e]*", + False, + None, + [ + "subdir0/subfile1", + "subdir0/subfile2", + "subdir1/subfile1", + "subdir1/subfile2", + ], + ), + ( + "subdir[0-1]/*fil[e]*", + True, + None, + [ + "subdir0/subfile1", + "subdir0/subfile2", + "subdir1/subfile1", + "subdir1/subfile2", + ], + ), + ], +} diff --git a/fsspec/tests/abstract/copy.py b/fsspec/tests/abstract/copy.py new file mode 100644 index 0000000000000000000000000000000000000000..e39e57e5f7d52bfda8ab5e2398b04cc2303630a0 --- /dev/null +++ b/fsspec/tests/abstract/copy.py @@ -0,0 +1,557 @@ +from hashlib import md5 +from itertools import product + +import pytest + +from fsspec.tests.abstract.common import GLOB_EDGE_CASES_TESTS + + +class AbstractCopyTests: + def test_copy_file_to_existing_directory( + self, + fs, + fs_join, + fs_bulk_operations_scenario_0, + fs_target, + supports_empty_directories, + ): + # Copy scenario 1a + source = fs_bulk_operations_scenario_0 + + target = fs_target + fs.mkdir(target) + if not supports_empty_directories: + # Force target directory to exist by adding a dummy file + fs.touch(fs_join(target, "dummy")) + assert fs.isdir(target) + + target_file2 = fs_join(target, "file2") + target_subfile1 = fs_join(target, "subfile1") + + # Copy from source directory + fs.cp(fs_join(source, "file2"), target) + assert fs.isfile(target_file2) + + # Copy from sub directory + fs.cp(fs_join(source, "subdir", "subfile1"), target) + assert fs.isfile(target_subfile1) + + # Remove copied files + fs.rm([target_file2, target_subfile1]) + assert not fs.exists(target_file2) + assert not fs.exists(target_subfile1) + + # Repeat with trailing slash on target + fs.cp(fs_join(source, "file2"), target + "/") + assert fs.isdir(target) + assert fs.isfile(target_file2) + + fs.cp(fs_join(source, "subdir", "subfile1"), target + "/") + assert fs.isfile(target_subfile1) + + def test_copy_file_to_new_directory( + self, fs, fs_join, fs_bulk_operations_scenario_0, fs_target + ): + # Copy scenario 1b + source = fs_bulk_operations_scenario_0 + + target = fs_target + fs.mkdir(target) + + fs.cp( + fs_join(source, "subdir", "subfile1"), fs_join(target, "newdir/") + ) # Note trailing slash + assert fs.isdir(target) + assert fs.isdir(fs_join(target, "newdir")) + assert fs.isfile(fs_join(target, "newdir", "subfile1")) + + def test_copy_file_to_file_in_existing_directory( + self, + fs, + fs_join, + fs_bulk_operations_scenario_0, + fs_target, + supports_empty_directories, + ): + # Copy scenario 1c + source = fs_bulk_operations_scenario_0 + + target = fs_target + fs.mkdir(target) + if not supports_empty_directories: + # Force target directory to exist by adding a dummy file + fs.touch(fs_join(target, "dummy")) + assert fs.isdir(target) + + fs.cp(fs_join(source, "subdir", "subfile1"), fs_join(target, "newfile")) + assert fs.isfile(fs_join(target, "newfile")) + + def test_copy_file_to_file_in_new_directory( + self, fs, fs_join, fs_bulk_operations_scenario_0, fs_target + ): + # Copy scenario 1d + source = fs_bulk_operations_scenario_0 + + target = fs_target + fs.mkdir(target) + + fs.cp( + fs_join(source, "subdir", "subfile1"), fs_join(target, "newdir", "newfile") + ) + assert fs.isdir(fs_join(target, "newdir")) + assert fs.isfile(fs_join(target, "newdir", "newfile")) + + def test_copy_directory_to_existing_directory( + self, + fs, + fs_join, + fs_bulk_operations_scenario_0, + fs_target, + supports_empty_directories, + ): + # Copy scenario 1e + source = fs_bulk_operations_scenario_0 + + target = fs_target + fs.mkdir(target) + if not supports_empty_directories: + # Force target directory to exist by adding a dummy file + dummy = fs_join(target, "dummy") + fs.touch(dummy) + assert fs.isdir(target) + + for source_slash, target_slash in zip([False, True], [False, True]): + s = fs_join(source, "subdir") + if source_slash: + s += "/" + t = target + "/" if target_slash else target + + # Without recursive does nothing + fs.cp(s, t) + assert fs.ls(target, detail=False) == ( + [] if supports_empty_directories else [dummy] + ) + + # With recursive + fs.cp(s, t, recursive=True) + if source_slash: + assert fs.isfile(fs_join(target, "subfile1")) + assert fs.isfile(fs_join(target, "subfile2")) + assert fs.isdir(fs_join(target, "nesteddir")) + assert fs.isfile(fs_join(target, "nesteddir", "nestedfile")) + assert not fs.exists(fs_join(target, "subdir")) + + fs.rm( + [ + fs_join(target, "subfile1"), + fs_join(target, "subfile2"), + fs_join(target, "nesteddir"), + ], + recursive=True, + ) + else: + assert fs.isdir(fs_join(target, "subdir")) + assert fs.isfile(fs_join(target, "subdir", "subfile1")) + assert fs.isfile(fs_join(target, "subdir", "subfile2")) + assert fs.isdir(fs_join(target, "subdir", "nesteddir")) + assert fs.isfile(fs_join(target, "subdir", "nesteddir", "nestedfile")) + + fs.rm(fs_join(target, "subdir"), recursive=True) + assert fs.ls(target, detail=False) == ( + [] if supports_empty_directories else [dummy] + ) + + # Limit recursive by maxdepth + fs.cp(s, t, recursive=True, maxdepth=1) + if source_slash: + assert fs.isfile(fs_join(target, "subfile1")) + assert fs.isfile(fs_join(target, "subfile2")) + assert not fs.exists(fs_join(target, "nesteddir")) + assert not fs.exists(fs_join(target, "subdir")) + + fs.rm( + [ + fs_join(target, "subfile1"), + fs_join(target, "subfile2"), + ], + recursive=True, + ) + else: + assert fs.isdir(fs_join(target, "subdir")) + assert fs.isfile(fs_join(target, "subdir", "subfile1")) + assert fs.isfile(fs_join(target, "subdir", "subfile2")) + assert not fs.exists(fs_join(target, "subdir", "nesteddir")) + + fs.rm(fs_join(target, "subdir"), recursive=True) + assert fs.ls(target, detail=False) == ( + [] if supports_empty_directories else [dummy] + ) + + def test_copy_directory_to_new_directory( + self, + fs, + fs_join, + fs_bulk_operations_scenario_0, + fs_target, + supports_empty_directories, + ): + # Copy scenario 1f + source = fs_bulk_operations_scenario_0 + + target = fs_target + fs.mkdir(target) + + for source_slash, target_slash in zip([False, True], [False, True]): + s = fs_join(source, "subdir") + if source_slash: + s += "/" + t = fs_join(target, "newdir") + if target_slash: + t += "/" + + # Without recursive does nothing + fs.cp(s, t) + if supports_empty_directories: + assert fs.ls(target) == [] + else: + with pytest.raises(FileNotFoundError): + fs.ls(target) + + # With recursive + fs.cp(s, t, recursive=True) + assert fs.isdir(fs_join(target, "newdir")) + assert fs.isfile(fs_join(target, "newdir", "subfile1")) + assert fs.isfile(fs_join(target, "newdir", "subfile2")) + assert fs.isdir(fs_join(target, "newdir", "nesteddir")) + assert fs.isfile(fs_join(target, "newdir", "nesteddir", "nestedfile")) + assert not fs.exists(fs_join(target, "subdir")) + + fs.rm(fs_join(target, "newdir"), recursive=True) + assert not fs.exists(fs_join(target, "newdir")) + + # Limit recursive by maxdepth + fs.cp(s, t, recursive=True, maxdepth=1) + assert fs.isdir(fs_join(target, "newdir")) + assert fs.isfile(fs_join(target, "newdir", "subfile1")) + assert fs.isfile(fs_join(target, "newdir", "subfile2")) + assert not fs.exists(fs_join(target, "newdir", "nesteddir")) + assert not fs.exists(fs_join(target, "subdir")) + + fs.rm(fs_join(target, "newdir"), recursive=True) + assert not fs.exists(fs_join(target, "newdir")) + + def test_copy_glob_to_existing_directory( + self, + fs, + fs_join, + fs_bulk_operations_scenario_0, + fs_target, + supports_empty_directories, + ): + # Copy scenario 1g + source = fs_bulk_operations_scenario_0 + + target = fs_target + fs.mkdir(target) + if not supports_empty_directories: + # Force target directory to exist by adding a dummy file + dummy = fs_join(target, "dummy") + fs.touch(dummy) + assert fs.isdir(target) + + for target_slash in [False, True]: + t = target + "/" if target_slash else target + + # Without recursive + fs.cp(fs_join(source, "subdir", "*"), t) + assert fs.isfile(fs_join(target, "subfile1")) + assert fs.isfile(fs_join(target, "subfile2")) + assert not fs.isdir(fs_join(target, "nesteddir")) + assert not fs.exists(fs_join(target, "nesteddir", "nestedfile")) + assert not fs.exists(fs_join(target, "subdir")) + + fs.rm( + [ + fs_join(target, "subfile1"), + fs_join(target, "subfile2"), + ], + recursive=True, + ) + assert fs.ls(target, detail=False) == ( + [] if supports_empty_directories else [dummy] + ) + + # With recursive + for glob, recursive in zip(["*", "**"], [True, False]): + fs.cp(fs_join(source, "subdir", glob), t, recursive=recursive) + assert fs.isfile(fs_join(target, "subfile1")) + assert fs.isfile(fs_join(target, "subfile2")) + assert fs.isdir(fs_join(target, "nesteddir")) + assert fs.isfile(fs_join(target, "nesteddir", "nestedfile")) + assert not fs.exists(fs_join(target, "subdir")) + + fs.rm( + [ + fs_join(target, "subfile1"), + fs_join(target, "subfile2"), + fs_join(target, "nesteddir"), + ], + recursive=True, + ) + assert fs.ls(target, detail=False) == ( + [] if supports_empty_directories else [dummy] + ) + + # Limit recursive by maxdepth + fs.cp( + fs_join(source, "subdir", glob), t, recursive=recursive, maxdepth=1 + ) + assert fs.isfile(fs_join(target, "subfile1")) + assert fs.isfile(fs_join(target, "subfile2")) + assert not fs.exists(fs_join(target, "nesteddir")) + assert not fs.exists(fs_join(target, "subdir")) + + fs.rm( + [ + fs_join(target, "subfile1"), + fs_join(target, "subfile2"), + ], + recursive=True, + ) + assert fs.ls(target, detail=False) == ( + [] if supports_empty_directories else [dummy] + ) + + def test_copy_glob_to_new_directory( + self, fs, fs_join, fs_bulk_operations_scenario_0, fs_target + ): + # Copy scenario 1h + source = fs_bulk_operations_scenario_0 + + target = fs_target + fs.mkdir(target) + + for target_slash in [False, True]: + t = fs_join(target, "newdir") + if target_slash: + t += "/" + + # Without recursive + fs.cp(fs_join(source, "subdir", "*"), t) + assert fs.isdir(fs_join(target, "newdir")) + assert fs.isfile(fs_join(target, "newdir", "subfile1")) + assert fs.isfile(fs_join(target, "newdir", "subfile2")) + assert not fs.exists(fs_join(target, "newdir", "nesteddir")) + assert not fs.exists(fs_join(target, "newdir", "nesteddir", "nestedfile")) + assert not fs.exists(fs_join(target, "subdir")) + assert not fs.exists(fs_join(target, "newdir", "subdir")) + + fs.rm(fs_join(target, "newdir"), recursive=True) + assert not fs.exists(fs_join(target, "newdir")) + + # With recursive + for glob, recursive in zip(["*", "**"], [True, False]): + fs.cp(fs_join(source, "subdir", glob), t, recursive=recursive) + assert fs.isdir(fs_join(target, "newdir")) + assert fs.isfile(fs_join(target, "newdir", "subfile1")) + assert fs.isfile(fs_join(target, "newdir", "subfile2")) + assert fs.isdir(fs_join(target, "newdir", "nesteddir")) + assert fs.isfile(fs_join(target, "newdir", "nesteddir", "nestedfile")) + assert not fs.exists(fs_join(target, "subdir")) + assert not fs.exists(fs_join(target, "newdir", "subdir")) + + fs.rm(fs_join(target, "newdir"), recursive=True) + assert not fs.exists(fs_join(target, "newdir")) + + # Limit recursive by maxdepth + fs.cp( + fs_join(source, "subdir", glob), t, recursive=recursive, maxdepth=1 + ) + assert fs.isdir(fs_join(target, "newdir")) + assert fs.isfile(fs_join(target, "newdir", "subfile1")) + assert fs.isfile(fs_join(target, "newdir", "subfile2")) + assert not fs.exists(fs_join(target, "newdir", "nesteddir")) + assert not fs.exists(fs_join(target, "subdir")) + assert not fs.exists(fs_join(target, "newdir", "subdir")) + + fs.rm(fs_join(target, "newdir"), recursive=True) + assert not fs.exists(fs_join(target, "newdir")) + + @pytest.mark.parametrize( + GLOB_EDGE_CASES_TESTS["argnames"], + GLOB_EDGE_CASES_TESTS["argvalues"], + ) + def test_copy_glob_edge_cases( + self, + path, + recursive, + maxdepth, + expected, + fs, + fs_join, + fs_glob_edge_cases_files, + fs_target, + fs_sanitize_path, + ): + # Copy scenario 1g + source = fs_glob_edge_cases_files + + target = fs_target + + for new_dir, target_slash in product([True, False], [True, False]): + fs.mkdir(target) + + t = fs_join(target, "newdir") if new_dir else target + t = t + "/" if target_slash else t + + fs.copy(fs_join(source, path), t, recursive=recursive, maxdepth=maxdepth) + + output = fs.find(target) + if new_dir: + prefixed_expected = [ + fs_sanitize_path(fs_join(target, "newdir", p)) for p in expected + ] + else: + prefixed_expected = [ + fs_sanitize_path(fs_join(target, p)) for p in expected + ] + assert sorted(output) == sorted(prefixed_expected) + + try: + fs.rm(target, recursive=True) + except FileNotFoundError: + pass + + def test_copy_list_of_files_to_existing_directory( + self, + fs, + fs_join, + fs_bulk_operations_scenario_0, + fs_target, + supports_empty_directories, + ): + # Copy scenario 2a + source = fs_bulk_operations_scenario_0 + + target = fs_target + fs.mkdir(target) + if not supports_empty_directories: + # Force target directory to exist by adding a dummy file + dummy = fs_join(target, "dummy") + fs.touch(dummy) + assert fs.isdir(target) + + source_files = [ + fs_join(source, "file1"), + fs_join(source, "file2"), + fs_join(source, "subdir", "subfile1"), + ] + + for target_slash in [False, True]: + t = target + "/" if target_slash else target + + fs.cp(source_files, t) + assert fs.isfile(fs_join(target, "file1")) + assert fs.isfile(fs_join(target, "file2")) + assert fs.isfile(fs_join(target, "subfile1")) + + fs.rm( + [ + fs_join(target, "file1"), + fs_join(target, "file2"), + fs_join(target, "subfile1"), + ], + recursive=True, + ) + assert fs.ls(target, detail=False) == ( + [] if supports_empty_directories else [dummy] + ) + + def test_copy_list_of_files_to_new_directory( + self, fs, fs_join, fs_bulk_operations_scenario_0, fs_target + ): + # Copy scenario 2b + source = fs_bulk_operations_scenario_0 + + target = fs_target + fs.mkdir(target) + + source_files = [ + fs_join(source, "file1"), + fs_join(source, "file2"), + fs_join(source, "subdir", "subfile1"), + ] + + fs.cp(source_files, fs_join(target, "newdir") + "/") # Note trailing slash + assert fs.isdir(fs_join(target, "newdir")) + assert fs.isfile(fs_join(target, "newdir", "file1")) + assert fs.isfile(fs_join(target, "newdir", "file2")) + assert fs.isfile(fs_join(target, "newdir", "subfile1")) + + def test_copy_two_files_new_directory( + self, fs, fs_join, fs_bulk_operations_scenario_0, fs_target + ): + # This is a duplicate of test_copy_list_of_files_to_new_directory and + # can eventually be removed. + source = fs_bulk_operations_scenario_0 + + target = fs_target + assert not fs.exists(target) + fs.cp([fs_join(source, "file1"), fs_join(source, "file2")], target) + + assert fs.isdir(target) + assert fs.isfile(fs_join(target, "file1")) + assert fs.isfile(fs_join(target, "file2")) + + def test_copy_directory_without_files_with_same_name_prefix( + self, + fs, + fs_join, + fs_target, + fs_dir_and_file_with_same_name_prefix, + supports_empty_directories, + ): + # Create the test dirs + source = fs_dir_and_file_with_same_name_prefix + target = fs_target + + # Test without glob + fs.cp(fs_join(source, "subdir"), target, recursive=True) + + assert fs.isfile(fs_join(target, "subfile.txt")) + assert not fs.isfile(fs_join(target, "subdir.txt")) + + fs.rm([fs_join(target, "subfile.txt")]) + if supports_empty_directories: + assert fs.ls(target) == [] + else: + assert not fs.exists(target) + + # Test with glob + fs.cp(fs_join(source, "subdir*"), target, recursive=True) + + assert fs.isdir(fs_join(target, "subdir")) + assert fs.isfile(fs_join(target, "subdir", "subfile.txt")) + assert fs.isfile(fs_join(target, "subdir.txt")) + + def test_copy_with_source_and_destination_as_list( + self, fs, fs_target, fs_join, fs_10_files_with_hashed_names + ): + # Create the test dir + source = fs_10_files_with_hashed_names + target = fs_target + + # Create list of files for source and destination + source_files = [] + destination_files = [] + for i in range(10): + hashed_i = md5(str(i).encode("utf-8")).hexdigest() + source_files.append(fs_join(source, f"{hashed_i}.txt")) + destination_files.append(fs_join(target, f"{hashed_i}.txt")) + + # Copy and assert order was kept + fs.copy(path1=source_files, path2=destination_files) + + for i in range(10): + file_content = fs.cat(destination_files[i]).decode("utf-8") + assert file_content == str(i) diff --git a/fsspec/tests/abstract/get.py b/fsspec/tests/abstract/get.py new file mode 100644 index 0000000000000000000000000000000000000000..851ab81ee581e74cac41c64c83ef0af75826d6b0 --- /dev/null +++ b/fsspec/tests/abstract/get.py @@ -0,0 +1,587 @@ +from hashlib import md5 +from itertools import product + +import pytest + +from fsspec.implementations.local import make_path_posix +from fsspec.tests.abstract.common import GLOB_EDGE_CASES_TESTS + + +class AbstractGetTests: + def test_get_file_to_existing_directory( + self, + fs, + fs_join, + fs_bulk_operations_scenario_0, + local_fs, + local_join, + local_target, + ): + # Copy scenario 1a + source = fs_bulk_operations_scenario_0 + + target = local_target + local_fs.mkdir(target) + assert local_fs.isdir(target) + + target_file2 = local_join(target, "file2") + target_subfile1 = local_join(target, "subfile1") + + # Copy from source directory + fs.get(fs_join(source, "file2"), target) + assert local_fs.isfile(target_file2) + + # Copy from sub directory + fs.get(fs_join(source, "subdir", "subfile1"), target) + assert local_fs.isfile(target_subfile1) + + # Remove copied files + local_fs.rm([target_file2, target_subfile1]) + assert not local_fs.exists(target_file2) + assert not local_fs.exists(target_subfile1) + + # Repeat with trailing slash on target + fs.get(fs_join(source, "file2"), target + "/") + assert local_fs.isdir(target) + assert local_fs.isfile(target_file2) + + fs.get(fs_join(source, "subdir", "subfile1"), target + "/") + assert local_fs.isfile(target_subfile1) + + def test_get_file_to_new_directory( + self, + fs, + fs_join, + fs_bulk_operations_scenario_0, + local_fs, + local_join, + local_target, + ): + # Copy scenario 1b + source = fs_bulk_operations_scenario_0 + + target = local_target + local_fs.mkdir(target) + + fs.get( + fs_join(source, "subdir", "subfile1"), local_join(target, "newdir/") + ) # Note trailing slash + + assert local_fs.isdir(target) + assert local_fs.isdir(local_join(target, "newdir")) + assert local_fs.isfile(local_join(target, "newdir", "subfile1")) + + def test_get_file_to_file_in_existing_directory( + self, + fs, + fs_join, + fs_bulk_operations_scenario_0, + local_fs, + local_join, + local_target, + ): + # Copy scenario 1c + source = fs_bulk_operations_scenario_0 + + target = local_target + local_fs.mkdir(target) + + fs.get(fs_join(source, "subdir", "subfile1"), local_join(target, "newfile")) + assert local_fs.isfile(local_join(target, "newfile")) + + def test_get_file_to_file_in_new_directory( + self, + fs, + fs_join, + fs_bulk_operations_scenario_0, + local_fs, + local_join, + local_target, + ): + # Copy scenario 1d + source = fs_bulk_operations_scenario_0 + + target = local_target + local_fs.mkdir(target) + + fs.get( + fs_join(source, "subdir", "subfile1"), + local_join(target, "newdir", "newfile"), + ) + assert local_fs.isdir(local_join(target, "newdir")) + assert local_fs.isfile(local_join(target, "newdir", "newfile")) + + def test_get_directory_to_existing_directory( + self, + fs, + fs_join, + fs_bulk_operations_scenario_0, + local_fs, + local_join, + local_target, + ): + # Copy scenario 1e + source = fs_bulk_operations_scenario_0 + + target = local_target + local_fs.mkdir(target) + assert local_fs.isdir(target) + + for source_slash, target_slash in zip([False, True], [False, True]): + s = fs_join(source, "subdir") + if source_slash: + s += "/" + t = target + "/" if target_slash else target + + # Without recursive does nothing + fs.get(s, t) + assert local_fs.ls(target) == [] + + # With recursive + fs.get(s, t, recursive=True) + if source_slash: + assert local_fs.isfile(local_join(target, "subfile1")) + assert local_fs.isfile(local_join(target, "subfile2")) + assert local_fs.isdir(local_join(target, "nesteddir")) + assert local_fs.isfile(local_join(target, "nesteddir", "nestedfile")) + assert not local_fs.exists(local_join(target, "subdir")) + + local_fs.rm( + [ + local_join(target, "subfile1"), + local_join(target, "subfile2"), + local_join(target, "nesteddir"), + ], + recursive=True, + ) + else: + assert local_fs.isdir(local_join(target, "subdir")) + assert local_fs.isfile(local_join(target, "subdir", "subfile1")) + assert local_fs.isfile(local_join(target, "subdir", "subfile2")) + assert local_fs.isdir(local_join(target, "subdir", "nesteddir")) + assert local_fs.isfile( + local_join(target, "subdir", "nesteddir", "nestedfile") + ) + + local_fs.rm(local_join(target, "subdir"), recursive=True) + assert local_fs.ls(target) == [] + + # Limit recursive by maxdepth + fs.get(s, t, recursive=True, maxdepth=1) + if source_slash: + assert local_fs.isfile(local_join(target, "subfile1")) + assert local_fs.isfile(local_join(target, "subfile2")) + assert not local_fs.exists(local_join(target, "nesteddir")) + assert not local_fs.exists(local_join(target, "subdir")) + + local_fs.rm( + [ + local_join(target, "subfile1"), + local_join(target, "subfile2"), + ], + recursive=True, + ) + else: + assert local_fs.isdir(local_join(target, "subdir")) + assert local_fs.isfile(local_join(target, "subdir", "subfile1")) + assert local_fs.isfile(local_join(target, "subdir", "subfile2")) + assert not local_fs.exists(local_join(target, "subdir", "nesteddir")) + + local_fs.rm(local_join(target, "subdir"), recursive=True) + assert local_fs.ls(target) == [] + + def test_get_directory_to_new_directory( + self, + fs, + fs_join, + fs_bulk_operations_scenario_0, + local_fs, + local_join, + local_target, + ): + # Copy scenario 1f + source = fs_bulk_operations_scenario_0 + + target = local_target + local_fs.mkdir(target) + + for source_slash, target_slash in zip([False, True], [False, True]): + s = fs_join(source, "subdir") + if source_slash: + s += "/" + t = local_join(target, "newdir") + if target_slash: + t += "/" + + # Without recursive does nothing + fs.get(s, t) + assert local_fs.ls(target) == [] + + # With recursive + fs.get(s, t, recursive=True) + assert local_fs.isdir(local_join(target, "newdir")) + assert local_fs.isfile(local_join(target, "newdir", "subfile1")) + assert local_fs.isfile(local_join(target, "newdir", "subfile2")) + assert local_fs.isdir(local_join(target, "newdir", "nesteddir")) + assert local_fs.isfile( + local_join(target, "newdir", "nesteddir", "nestedfile") + ) + assert not local_fs.exists(local_join(target, "subdir")) + + local_fs.rm(local_join(target, "newdir"), recursive=True) + assert local_fs.ls(target) == [] + + # Limit recursive by maxdepth + fs.get(s, t, recursive=True, maxdepth=1) + assert local_fs.isdir(local_join(target, "newdir")) + assert local_fs.isfile(local_join(target, "newdir", "subfile1")) + assert local_fs.isfile(local_join(target, "newdir", "subfile2")) + assert not local_fs.exists(local_join(target, "newdir", "nesteddir")) + assert not local_fs.exists(local_join(target, "subdir")) + + local_fs.rm(local_join(target, "newdir"), recursive=True) + assert not local_fs.exists(local_join(target, "newdir")) + + def test_get_glob_to_existing_directory( + self, + fs, + fs_join, + fs_bulk_operations_scenario_0, + local_fs, + local_join, + local_target, + ): + # Copy scenario 1g + source = fs_bulk_operations_scenario_0 + + target = local_target + local_fs.mkdir(target) + + for target_slash in [False, True]: + t = target + "/" if target_slash else target + + # Without recursive + fs.get(fs_join(source, "subdir", "*"), t) + assert local_fs.isfile(local_join(target, "subfile1")) + assert local_fs.isfile(local_join(target, "subfile2")) + assert not local_fs.isdir(local_join(target, "nesteddir")) + assert not local_fs.exists(local_join(target, "nesteddir", "nestedfile")) + assert not local_fs.exists(local_join(target, "subdir")) + + local_fs.rm( + [ + local_join(target, "subfile1"), + local_join(target, "subfile2"), + ], + recursive=True, + ) + assert local_fs.ls(target) == [] + + # With recursive + for glob, recursive in zip(["*", "**"], [True, False]): + fs.get(fs_join(source, "subdir", glob), t, recursive=recursive) + assert local_fs.isfile(local_join(target, "subfile1")) + assert local_fs.isfile(local_join(target, "subfile2")) + assert local_fs.isdir(local_join(target, "nesteddir")) + assert local_fs.isfile(local_join(target, "nesteddir", "nestedfile")) + assert not local_fs.exists(local_join(target, "subdir")) + + local_fs.rm( + [ + local_join(target, "subfile1"), + local_join(target, "subfile2"), + local_join(target, "nesteddir"), + ], + recursive=True, + ) + assert local_fs.ls(target) == [] + + # Limit recursive by maxdepth + fs.get( + fs_join(source, "subdir", glob), t, recursive=recursive, maxdepth=1 + ) + assert local_fs.isfile(local_join(target, "subfile1")) + assert local_fs.isfile(local_join(target, "subfile2")) + assert not local_fs.exists(local_join(target, "nesteddir")) + assert not local_fs.exists(local_join(target, "subdir")) + + local_fs.rm( + [ + local_join(target, "subfile1"), + local_join(target, "subfile2"), + ], + recursive=True, + ) + assert local_fs.ls(target) == [] + + def test_get_glob_to_new_directory( + self, + fs, + fs_join, + fs_bulk_operations_scenario_0, + local_fs, + local_join, + local_target, + ): + # Copy scenario 1h + source = fs_bulk_operations_scenario_0 + + target = local_target + local_fs.mkdir(target) + + for target_slash in [False, True]: + t = fs_join(target, "newdir") + if target_slash: + t += "/" + + # Without recursive + fs.get(fs_join(source, "subdir", "*"), t) + assert local_fs.isdir(local_join(target, "newdir")) + assert local_fs.isfile(local_join(target, "newdir", "subfile1")) + assert local_fs.isfile(local_join(target, "newdir", "subfile2")) + assert not local_fs.exists(local_join(target, "newdir", "nesteddir")) + assert not local_fs.exists( + local_join(target, "newdir", "nesteddir", "nestedfile") + ) + assert not local_fs.exists(local_join(target, "subdir")) + assert not local_fs.exists(local_join(target, "newdir", "subdir")) + + local_fs.rm(local_join(target, "newdir"), recursive=True) + assert local_fs.ls(target) == [] + + # With recursive + for glob, recursive in zip(["*", "**"], [True, False]): + fs.get(fs_join(source, "subdir", glob), t, recursive=recursive) + assert local_fs.isdir(local_join(target, "newdir")) + assert local_fs.isfile(local_join(target, "newdir", "subfile1")) + assert local_fs.isfile(local_join(target, "newdir", "subfile2")) + assert local_fs.isdir(local_join(target, "newdir", "nesteddir")) + assert local_fs.isfile( + local_join(target, "newdir", "nesteddir", "nestedfile") + ) + assert not local_fs.exists(local_join(target, "subdir")) + assert not local_fs.exists(local_join(target, "newdir", "subdir")) + + local_fs.rm(local_join(target, "newdir"), recursive=True) + assert not local_fs.exists(local_join(target, "newdir")) + + # Limit recursive by maxdepth + fs.get( + fs_join(source, "subdir", glob), t, recursive=recursive, maxdepth=1 + ) + assert local_fs.isdir(local_join(target, "newdir")) + assert local_fs.isfile(local_join(target, "newdir", "subfile1")) + assert local_fs.isfile(local_join(target, "newdir", "subfile2")) + assert not local_fs.exists(local_join(target, "newdir", "nesteddir")) + assert not local_fs.exists(local_join(target, "subdir")) + assert not local_fs.exists(local_join(target, "newdir", "subdir")) + + local_fs.rm(local_fs.ls(target, detail=False), recursive=True) + assert not local_fs.exists(local_join(target, "newdir")) + + @pytest.mark.parametrize( + GLOB_EDGE_CASES_TESTS["argnames"], + GLOB_EDGE_CASES_TESTS["argvalues"], + ) + def test_get_glob_edge_cases( + self, + path, + recursive, + maxdepth, + expected, + fs, + fs_join, + fs_glob_edge_cases_files, + local_fs, + local_join, + local_target, + ): + # Copy scenario 1g + source = fs_glob_edge_cases_files + + target = local_target + + for new_dir, target_slash in product([True, False], [True, False]): + local_fs.mkdir(target) + + t = local_join(target, "newdir") if new_dir else target + t = t + "/" if target_slash else t + + fs.get(fs_join(source, path), t, recursive=recursive, maxdepth=maxdepth) + + output = local_fs.find(target) + if new_dir: + prefixed_expected = [ + make_path_posix(local_join(target, "newdir", p)) for p in expected + ] + else: + prefixed_expected = [ + make_path_posix(local_join(target, p)) for p in expected + ] + assert sorted(output) == sorted(prefixed_expected) + + try: + local_fs.rm(target, recursive=True) + except FileNotFoundError: + pass + + def test_get_list_of_files_to_existing_directory( + self, + fs, + fs_join, + fs_bulk_operations_scenario_0, + local_fs, + local_join, + local_target, + ): + # Copy scenario 2a + source = fs_bulk_operations_scenario_0 + + target = local_target + local_fs.mkdir(target) + + source_files = [ + fs_join(source, "file1"), + fs_join(source, "file2"), + fs_join(source, "subdir", "subfile1"), + ] + + for target_slash in [False, True]: + t = target + "/" if target_slash else target + + fs.get(source_files, t) + assert local_fs.isfile(local_join(target, "file1")) + assert local_fs.isfile(local_join(target, "file2")) + assert local_fs.isfile(local_join(target, "subfile1")) + + local_fs.rm( + [ + local_join(target, "file1"), + local_join(target, "file2"), + local_join(target, "subfile1"), + ], + recursive=True, + ) + assert local_fs.ls(target) == [] + + def test_get_list_of_files_to_new_directory( + self, + fs, + fs_join, + fs_bulk_operations_scenario_0, + local_fs, + local_join, + local_target, + ): + # Copy scenario 2b + source = fs_bulk_operations_scenario_0 + + target = local_target + local_fs.mkdir(target) + + source_files = [ + fs_join(source, "file1"), + fs_join(source, "file2"), + fs_join(source, "subdir", "subfile1"), + ] + + fs.get(source_files, local_join(target, "newdir") + "/") # Note trailing slash + assert local_fs.isdir(local_join(target, "newdir")) + assert local_fs.isfile(local_join(target, "newdir", "file1")) + assert local_fs.isfile(local_join(target, "newdir", "file2")) + assert local_fs.isfile(local_join(target, "newdir", "subfile1")) + + def test_get_directory_recursive( + self, fs, fs_join, fs_path, local_fs, local_join, local_target + ): + # https://github.com/fsspec/filesystem_spec/issues/1062 + # Recursive cp/get/put of source directory into non-existent target directory. + src = fs_join(fs_path, "src") + src_file = fs_join(src, "file") + fs.mkdir(src) + fs.touch(src_file) + + target = local_target + + # get without slash + assert not local_fs.exists(target) + for loop in range(2): + fs.get(src, target, recursive=True) + assert local_fs.isdir(target) + + if loop == 0: + assert local_fs.isfile(local_join(target, "file")) + assert not local_fs.exists(local_join(target, "src")) + else: + assert local_fs.isfile(local_join(target, "file")) + assert local_fs.isdir(local_join(target, "src")) + assert local_fs.isfile(local_join(target, "src", "file")) + + local_fs.rm(target, recursive=True) + + # get with slash + assert not local_fs.exists(target) + for loop in range(2): + fs.get(src + "/", target, recursive=True) + assert local_fs.isdir(target) + assert local_fs.isfile(local_join(target, "file")) + assert not local_fs.exists(local_join(target, "src")) + + def test_get_directory_without_files_with_same_name_prefix( + self, + fs, + fs_join, + local_fs, + local_join, + local_target, + fs_dir_and_file_with_same_name_prefix, + ): + # Create the test dirs + source = fs_dir_and_file_with_same_name_prefix + target = local_target + + # Test without glob + fs.get(fs_join(source, "subdir"), target, recursive=True) + + assert local_fs.isfile(local_join(target, "subfile.txt")) + assert not local_fs.isfile(local_join(target, "subdir.txt")) + + local_fs.rm([local_join(target, "subfile.txt")]) + assert local_fs.ls(target) == [] + + # Test with glob + fs.get(fs_join(source, "subdir*"), target, recursive=True) + + assert local_fs.isdir(local_join(target, "subdir")) + assert local_fs.isfile(local_join(target, "subdir", "subfile.txt")) + assert local_fs.isfile(local_join(target, "subdir.txt")) + + def test_get_with_source_and_destination_as_list( + self, + fs, + fs_join, + local_fs, + local_join, + local_target, + fs_10_files_with_hashed_names, + ): + # Create the test dir + source = fs_10_files_with_hashed_names + target = local_target + + # Create list of files for source and destination + source_files = [] + destination_files = [] + for i in range(10): + hashed_i = md5(str(i).encode("utf-8")).hexdigest() + source_files.append(fs_join(source, f"{hashed_i}.txt")) + destination_files.append( + make_path_posix(local_join(target, f"{hashed_i}.txt")) + ) + + # Copy and assert order was kept + fs.get(rpath=source_files, lpath=destination_files) + + for i in range(10): + file_content = local_fs.cat(destination_files[i]).decode("utf-8") + assert file_content == str(i) diff --git a/fsspec/tests/abstract/mv.py b/fsspec/tests/abstract/mv.py new file mode 100644 index 0000000000000000000000000000000000000000..39f6caa3de815e024fa84de2acecc986c823ed29 --- /dev/null +++ b/fsspec/tests/abstract/mv.py @@ -0,0 +1,57 @@ +import os + +import pytest + +import fsspec + + +def test_move_raises_error_with_tmpdir(tmpdir): + # Create a file in the temporary directory + source = tmpdir.join("source_file.txt") + source.write("content") + + # Define a destination that simulates a protected or invalid path + destination = tmpdir.join("non_existent_directory/destination_file.txt") + + # Instantiate the filesystem (assuming the local file system interface) + fs = fsspec.filesystem("file") + + # Use the actual file paths as string + with pytest.raises(FileNotFoundError): + fs.mv(str(source), str(destination)) + + +@pytest.mark.parametrize("recursive", (True, False)) +def test_move_raises_error_with_tmpdir_permission(recursive, tmpdir): + # Create a file in the temporary directory + source = tmpdir.join("source_file.txt") + source.write("content") + + # Create a protected directory (non-writable) + protected_dir = tmpdir.mkdir("protected_directory") + protected_path = str(protected_dir) + + # Set the directory to read-only + if os.name == "nt": + os.system(f'icacls "{protected_path}" /deny Everyone:(W)') + else: + os.chmod(protected_path, 0o555) # Sets the directory to read-only + + # Define a destination inside the protected directory + destination = protected_dir.join("destination_file.txt") + + # Instantiate the filesystem (assuming the local file system interface) + fs = fsspec.filesystem("file") + + # Try to move the file to the read-only directory, expecting a permission error + with pytest.raises(PermissionError): + fs.mv(str(source), str(destination), recursive=recursive) + + # Assert the file was not created in the destination + assert not os.path.exists(destination) + + # Cleanup: Restore permissions so the directory can be cleaned up + if os.name == "nt": + os.system(f'icacls "{protected_path}" /remove:d Everyone') + else: + os.chmod(protected_path, 0o755) # Restore write permission for cleanup diff --git a/fsspec/tests/abstract/open.py b/fsspec/tests/abstract/open.py new file mode 100644 index 0000000000000000000000000000000000000000..bb75ea852276fb8d834345883813b8e27a0ae24c --- /dev/null +++ b/fsspec/tests/abstract/open.py @@ -0,0 +1,11 @@ +import pytest + + +class AbstractOpenTests: + def test_open_exclusive(self, fs, fs_target): + with fs.open(fs_target, "wb") as f: + f.write(b"data") + with fs.open(fs_target, "rb") as f: + assert f.read() == b"data" + with pytest.raises(FileExistsError): + fs.open(fs_target, "xb") diff --git a/fsspec/tests/abstract/pipe.py b/fsspec/tests/abstract/pipe.py new file mode 100644 index 0000000000000000000000000000000000000000..8ecca96e9d23ff268a253c48269d5cca451ea270 --- /dev/null +++ b/fsspec/tests/abstract/pipe.py @@ -0,0 +1,11 @@ +import pytest + + +class AbstractPipeTests: + def test_pipe_exclusive(self, fs, fs_target): + fs.pipe_file(fs_target, b"data") + assert fs.cat_file(fs_target) == b"data" + with pytest.raises(FileExistsError): + fs.pipe_file(fs_target, b"data", mode="create") + fs.pipe_file(fs_target, b"new data", mode="overwrite") + assert fs.cat_file(fs_target) == b"new data" diff --git a/fsspec/tests/abstract/put.py b/fsspec/tests/abstract/put.py new file mode 100644 index 0000000000000000000000000000000000000000..9fc349977f0384d9fc86126498be5c6ad99a21d3 --- /dev/null +++ b/fsspec/tests/abstract/put.py @@ -0,0 +1,591 @@ +from hashlib import md5 +from itertools import product + +import pytest + +from fsspec.tests.abstract.common import GLOB_EDGE_CASES_TESTS + + +class AbstractPutTests: + def test_put_file_to_existing_directory( + self, + fs, + fs_join, + fs_target, + local_join, + local_bulk_operations_scenario_0, + supports_empty_directories, + ): + # Copy scenario 1a + source = local_bulk_operations_scenario_0 + + target = fs_target + fs.mkdir(target) + if not supports_empty_directories: + # Force target directory to exist by adding a dummy file + fs.touch(fs_join(target, "dummy")) + assert fs.isdir(target) + + target_file2 = fs_join(target, "file2") + target_subfile1 = fs_join(target, "subfile1") + + # Copy from source directory + fs.put(local_join(source, "file2"), target) + assert fs.isfile(target_file2) + + # Copy from sub directory + fs.put(local_join(source, "subdir", "subfile1"), target) + assert fs.isfile(target_subfile1) + + # Remove copied files + fs.rm([target_file2, target_subfile1]) + assert not fs.exists(target_file2) + assert not fs.exists(target_subfile1) + + # Repeat with trailing slash on target + fs.put(local_join(source, "file2"), target + "/") + assert fs.isdir(target) + assert fs.isfile(target_file2) + + fs.put(local_join(source, "subdir", "subfile1"), target + "/") + assert fs.isfile(target_subfile1) + + def test_put_file_to_new_directory( + self, fs, fs_join, fs_target, local_join, local_bulk_operations_scenario_0 + ): + # Copy scenario 1b + source = local_bulk_operations_scenario_0 + + target = fs_target + fs.mkdir(target) + + fs.put( + local_join(source, "subdir", "subfile1"), fs_join(target, "newdir/") + ) # Note trailing slash + assert fs.isdir(target) + assert fs.isdir(fs_join(target, "newdir")) + assert fs.isfile(fs_join(target, "newdir", "subfile1")) + + def test_put_file_to_file_in_existing_directory( + self, + fs, + fs_join, + fs_target, + local_join, + supports_empty_directories, + local_bulk_operations_scenario_0, + ): + # Copy scenario 1c + source = local_bulk_operations_scenario_0 + + target = fs_target + fs.mkdir(target) + if not supports_empty_directories: + # Force target directory to exist by adding a dummy file + fs.touch(fs_join(target, "dummy")) + assert fs.isdir(target) + + fs.put(local_join(source, "subdir", "subfile1"), fs_join(target, "newfile")) + assert fs.isfile(fs_join(target, "newfile")) + + def test_put_file_to_file_in_new_directory( + self, fs, fs_join, fs_target, local_join, local_bulk_operations_scenario_0 + ): + # Copy scenario 1d + source = local_bulk_operations_scenario_0 + + target = fs_target + fs.mkdir(target) + + fs.put( + local_join(source, "subdir", "subfile1"), + fs_join(target, "newdir", "newfile"), + ) + assert fs.isdir(fs_join(target, "newdir")) + assert fs.isfile(fs_join(target, "newdir", "newfile")) + + def test_put_directory_to_existing_directory( + self, + fs, + fs_join, + fs_target, + local_bulk_operations_scenario_0, + supports_empty_directories, + ): + # Copy scenario 1e + source = local_bulk_operations_scenario_0 + + target = fs_target + fs.mkdir(target) + if not supports_empty_directories: + # Force target directory to exist by adding a dummy file + dummy = fs_join(target, "dummy") + fs.touch(dummy) + assert fs.isdir(target) + + for source_slash, target_slash in zip([False, True], [False, True]): + s = fs_join(source, "subdir") + if source_slash: + s += "/" + t = target + "/" if target_slash else target + + # Without recursive does nothing + fs.put(s, t) + assert fs.ls(target, detail=False) == ( + [] if supports_empty_directories else [dummy] + ) + + # With recursive + fs.put(s, t, recursive=True) + if source_slash: + assert fs.isfile(fs_join(target, "subfile1")) + assert fs.isfile(fs_join(target, "subfile2")) + assert fs.isdir(fs_join(target, "nesteddir")) + assert fs.isfile(fs_join(target, "nesteddir", "nestedfile")) + assert not fs.exists(fs_join(target, "subdir")) + + fs.rm( + [ + fs_join(target, "subfile1"), + fs_join(target, "subfile2"), + fs_join(target, "nesteddir"), + ], + recursive=True, + ) + else: + assert fs.isdir(fs_join(target, "subdir")) + assert fs.isfile(fs_join(target, "subdir", "subfile1")) + assert fs.isfile(fs_join(target, "subdir", "subfile2")) + assert fs.isdir(fs_join(target, "subdir", "nesteddir")) + assert fs.isfile(fs_join(target, "subdir", "nesteddir", "nestedfile")) + + fs.rm(fs_join(target, "subdir"), recursive=True) + assert fs.ls(target, detail=False) == ( + [] if supports_empty_directories else [dummy] + ) + + # Limit recursive by maxdepth + fs.put(s, t, recursive=True, maxdepth=1) + if source_slash: + assert fs.isfile(fs_join(target, "subfile1")) + assert fs.isfile(fs_join(target, "subfile2")) + assert not fs.exists(fs_join(target, "nesteddir")) + assert not fs.exists(fs_join(target, "subdir")) + + fs.rm( + [ + fs_join(target, "subfile1"), + fs_join(target, "subfile2"), + ], + recursive=True, + ) + else: + assert fs.isdir(fs_join(target, "subdir")) + assert fs.isfile(fs_join(target, "subdir", "subfile1")) + assert fs.isfile(fs_join(target, "subdir", "subfile2")) + assert not fs.exists(fs_join(target, "subdir", "nesteddir")) + + fs.rm(fs_join(target, "subdir"), recursive=True) + assert fs.ls(target, detail=False) == ( + [] if supports_empty_directories else [dummy] + ) + + def test_put_directory_to_new_directory( + self, + fs, + fs_join, + fs_target, + local_bulk_operations_scenario_0, + supports_empty_directories, + ): + # Copy scenario 1f + source = local_bulk_operations_scenario_0 + + target = fs_target + fs.mkdir(target) + + for source_slash, target_slash in zip([False, True], [False, True]): + s = fs_join(source, "subdir") + if source_slash: + s += "/" + t = fs_join(target, "newdir") + if target_slash: + t += "/" + + # Without recursive does nothing + fs.put(s, t) + if supports_empty_directories: + assert fs.ls(target) == [] + else: + with pytest.raises(FileNotFoundError): + fs.ls(target) + + # With recursive + fs.put(s, t, recursive=True) + assert fs.isdir(fs_join(target, "newdir")) + assert fs.isfile(fs_join(target, "newdir", "subfile1")) + assert fs.isfile(fs_join(target, "newdir", "subfile2")) + assert fs.isdir(fs_join(target, "newdir", "nesteddir")) + assert fs.isfile(fs_join(target, "newdir", "nesteddir", "nestedfile")) + assert not fs.exists(fs_join(target, "subdir")) + + fs.rm(fs_join(target, "newdir"), recursive=True) + assert not fs.exists(fs_join(target, "newdir")) + + # Limit recursive by maxdepth + fs.put(s, t, recursive=True, maxdepth=1) + assert fs.isdir(fs_join(target, "newdir")) + assert fs.isfile(fs_join(target, "newdir", "subfile1")) + assert fs.isfile(fs_join(target, "newdir", "subfile2")) + assert not fs.exists(fs_join(target, "newdir", "nesteddir")) + assert not fs.exists(fs_join(target, "subdir")) + + fs.rm(fs_join(target, "newdir"), recursive=True) + assert not fs.exists(fs_join(target, "newdir")) + + def test_put_glob_to_existing_directory( + self, + fs, + fs_join, + fs_target, + local_join, + supports_empty_directories, + local_bulk_operations_scenario_0, + ): + # Copy scenario 1g + source = local_bulk_operations_scenario_0 + + target = fs_target + fs.mkdir(target) + if not supports_empty_directories: + # Force target directory to exist by adding a dummy file + dummy = fs_join(target, "dummy") + fs.touch(dummy) + assert fs.isdir(target) + + for target_slash in [False, True]: + t = target + "/" if target_slash else target + + # Without recursive + fs.put(local_join(source, "subdir", "*"), t) + assert fs.isfile(fs_join(target, "subfile1")) + assert fs.isfile(fs_join(target, "subfile2")) + assert not fs.isdir(fs_join(target, "nesteddir")) + assert not fs.exists(fs_join(target, "nesteddir", "nestedfile")) + assert not fs.exists(fs_join(target, "subdir")) + + fs.rm( + [ + fs_join(target, "subfile1"), + fs_join(target, "subfile2"), + ], + recursive=True, + ) + assert fs.ls(target, detail=False) == ( + [] if supports_empty_directories else [dummy] + ) + + # With recursive + for glob, recursive in zip(["*", "**"], [True, False]): + fs.put(local_join(source, "subdir", glob), t, recursive=recursive) + assert fs.isfile(fs_join(target, "subfile1")) + assert fs.isfile(fs_join(target, "subfile2")) + assert fs.isdir(fs_join(target, "nesteddir")) + assert fs.isfile(fs_join(target, "nesteddir", "nestedfile")) + assert not fs.exists(fs_join(target, "subdir")) + + fs.rm( + [ + fs_join(target, "subfile1"), + fs_join(target, "subfile2"), + fs_join(target, "nesteddir"), + ], + recursive=True, + ) + assert fs.ls(target, detail=False) == ( + [] if supports_empty_directories else [dummy] + ) + + # Limit recursive by maxdepth + fs.put( + local_join(source, "subdir", glob), + t, + recursive=recursive, + maxdepth=1, + ) + assert fs.isfile(fs_join(target, "subfile1")) + assert fs.isfile(fs_join(target, "subfile2")) + assert not fs.exists(fs_join(target, "nesteddir")) + assert not fs.exists(fs_join(target, "subdir")) + + fs.rm( + [ + fs_join(target, "subfile1"), + fs_join(target, "subfile2"), + ], + recursive=True, + ) + assert fs.ls(target, detail=False) == ( + [] if supports_empty_directories else [dummy] + ) + + def test_put_glob_to_new_directory( + self, fs, fs_join, fs_target, local_join, local_bulk_operations_scenario_0 + ): + # Copy scenario 1h + source = local_bulk_operations_scenario_0 + + target = fs_target + fs.mkdir(target) + + for target_slash in [False, True]: + t = fs_join(target, "newdir") + if target_slash: + t += "/" + + # Without recursive + fs.put(local_join(source, "subdir", "*"), t) + assert fs.isdir(fs_join(target, "newdir")) + assert fs.isfile(fs_join(target, "newdir", "subfile1")) + assert fs.isfile(fs_join(target, "newdir", "subfile2")) + assert not fs.exists(fs_join(target, "newdir", "nesteddir")) + assert not fs.exists(fs_join(target, "newdir", "nesteddir", "nestedfile")) + assert not fs.exists(fs_join(target, "subdir")) + assert not fs.exists(fs_join(target, "newdir", "subdir")) + + fs.rm(fs_join(target, "newdir"), recursive=True) + assert not fs.exists(fs_join(target, "newdir")) + + # With recursive + for glob, recursive in zip(["*", "**"], [True, False]): + fs.put(local_join(source, "subdir", glob), t, recursive=recursive) + assert fs.isdir(fs_join(target, "newdir")) + assert fs.isfile(fs_join(target, "newdir", "subfile1")) + assert fs.isfile(fs_join(target, "newdir", "subfile2")) + assert fs.isdir(fs_join(target, "newdir", "nesteddir")) + assert fs.isfile(fs_join(target, "newdir", "nesteddir", "nestedfile")) + assert not fs.exists(fs_join(target, "subdir")) + assert not fs.exists(fs_join(target, "newdir", "subdir")) + + fs.rm(fs_join(target, "newdir"), recursive=True) + assert not fs.exists(fs_join(target, "newdir")) + + # Limit recursive by maxdepth + fs.put( + local_join(source, "subdir", glob), + t, + recursive=recursive, + maxdepth=1, + ) + assert fs.isdir(fs_join(target, "newdir")) + assert fs.isfile(fs_join(target, "newdir", "subfile1")) + assert fs.isfile(fs_join(target, "newdir", "subfile2")) + assert not fs.exists(fs_join(target, "newdir", "nesteddir")) + assert not fs.exists(fs_join(target, "subdir")) + assert not fs.exists(fs_join(target, "newdir", "subdir")) + + fs.rm(fs_join(target, "newdir"), recursive=True) + assert not fs.exists(fs_join(target, "newdir")) + + @pytest.mark.parametrize( + GLOB_EDGE_CASES_TESTS["argnames"], + GLOB_EDGE_CASES_TESTS["argvalues"], + ) + def test_put_glob_edge_cases( + self, + path, + recursive, + maxdepth, + expected, + fs, + fs_join, + fs_target, + local_glob_edge_cases_files, + local_join, + fs_sanitize_path, + ): + # Copy scenario 1g + source = local_glob_edge_cases_files + + target = fs_target + + for new_dir, target_slash in product([True, False], [True, False]): + fs.mkdir(target) + + t = fs_join(target, "newdir") if new_dir else target + t = t + "/" if target_slash else t + + fs.put(local_join(source, path), t, recursive=recursive, maxdepth=maxdepth) + + output = fs.find(target) + if new_dir: + prefixed_expected = [ + fs_sanitize_path(fs_join(target, "newdir", p)) for p in expected + ] + else: + prefixed_expected = [ + fs_sanitize_path(fs_join(target, p)) for p in expected + ] + assert sorted(output) == sorted(prefixed_expected) + + try: + fs.rm(target, recursive=True) + except FileNotFoundError: + pass + + def test_put_list_of_files_to_existing_directory( + self, + fs, + fs_join, + fs_target, + local_join, + local_bulk_operations_scenario_0, + supports_empty_directories, + ): + # Copy scenario 2a + source = local_bulk_operations_scenario_0 + + target = fs_target + fs.mkdir(target) + if not supports_empty_directories: + # Force target directory to exist by adding a dummy file + dummy = fs_join(target, "dummy") + fs.touch(dummy) + assert fs.isdir(target) + + source_files = [ + local_join(source, "file1"), + local_join(source, "file2"), + local_join(source, "subdir", "subfile1"), + ] + + for target_slash in [False, True]: + t = target + "/" if target_slash else target + + fs.put(source_files, t) + assert fs.isfile(fs_join(target, "file1")) + assert fs.isfile(fs_join(target, "file2")) + assert fs.isfile(fs_join(target, "subfile1")) + + fs.rm( + [ + fs_join(target, "file1"), + fs_join(target, "file2"), + fs_join(target, "subfile1"), + ], + recursive=True, + ) + assert fs.ls(target, detail=False) == ( + [] if supports_empty_directories else [dummy] + ) + + def test_put_list_of_files_to_new_directory( + self, fs, fs_join, fs_target, local_join, local_bulk_operations_scenario_0 + ): + # Copy scenario 2b + source = local_bulk_operations_scenario_0 + + target = fs_target + fs.mkdir(target) + + source_files = [ + local_join(source, "file1"), + local_join(source, "file2"), + local_join(source, "subdir", "subfile1"), + ] + + fs.put(source_files, fs_join(target, "newdir") + "/") # Note trailing slash + assert fs.isdir(fs_join(target, "newdir")) + assert fs.isfile(fs_join(target, "newdir", "file1")) + assert fs.isfile(fs_join(target, "newdir", "file2")) + assert fs.isfile(fs_join(target, "newdir", "subfile1")) + + def test_put_directory_recursive( + self, fs, fs_join, fs_target, local_fs, local_join, local_path + ): + # https://github.com/fsspec/filesystem_spec/issues/1062 + # Recursive cp/get/put of source directory into non-existent target directory. + src = local_join(local_path, "src") + src_file = local_join(src, "file") + local_fs.mkdir(src) + local_fs.touch(src_file) + + target = fs_target + + # put without slash + assert not fs.exists(target) + for loop in range(2): + fs.put(src, target, recursive=True) + assert fs.isdir(target) + + if loop == 0: + assert fs.isfile(fs_join(target, "file")) + assert not fs.exists(fs_join(target, "src")) + else: + assert fs.isfile(fs_join(target, "file")) + assert fs.isdir(fs_join(target, "src")) + assert fs.isfile(fs_join(target, "src", "file")) + + fs.rm(target, recursive=True) + + # put with slash + assert not fs.exists(target) + for loop in range(2): + fs.put(src + "/", target, recursive=True) + assert fs.isdir(target) + assert fs.isfile(fs_join(target, "file")) + assert not fs.exists(fs_join(target, "src")) + + def test_put_directory_without_files_with_same_name_prefix( + self, + fs, + fs_join, + fs_target, + local_join, + local_dir_and_file_with_same_name_prefix, + supports_empty_directories, + ): + # Create the test dirs + source = local_dir_and_file_with_same_name_prefix + target = fs_target + + # Test without glob + fs.put(local_join(source, "subdir"), fs_target, recursive=True) + + assert fs.isfile(fs_join(fs_target, "subfile.txt")) + assert not fs.isfile(fs_join(fs_target, "subdir.txt")) + + fs.rm([fs_join(target, "subfile.txt")]) + if supports_empty_directories: + assert fs.ls(target) == [] + else: + assert not fs.exists(target) + + # Test with glob + fs.put(local_join(source, "subdir*"), fs_target, recursive=True) + + assert fs.isdir(fs_join(fs_target, "subdir")) + assert fs.isfile(fs_join(fs_target, "subdir", "subfile.txt")) + assert fs.isfile(fs_join(fs_target, "subdir.txt")) + + def test_copy_with_source_and_destination_as_list( + self, fs, fs_target, fs_join, local_join, local_10_files_with_hashed_names + ): + # Create the test dir + source = local_10_files_with_hashed_names + target = fs_target + + # Create list of files for source and destination + source_files = [] + destination_files = [] + for i in range(10): + hashed_i = md5(str(i).encode("utf-8")).hexdigest() + source_files.append(local_join(source, f"{hashed_i}.txt")) + destination_files.append(fs_join(target, f"{hashed_i}.txt")) + + # Copy and assert order was kept + fs.put(lpath=source_files, rpath=destination_files) + + for i in range(10): + file_content = fs.cat(destination_files[i]).decode("utf-8") + assert file_content == str(i) diff --git a/huggingface_hub/_hot_reload/__init__.py b/huggingface_hub/_hot_reload/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ef08599ad40fcb6d55fac1c397597f68b8b52084 --- /dev/null +++ b/huggingface_hub/_hot_reload/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/huggingface_hub/_hot_reload/client.py b/huggingface_hub/_hot_reload/client.py new file mode 100644 index 0000000000000000000000000000000000000000..d65f0731d1ee907e183e3d66dd9ea02a2a28402f --- /dev/null +++ b/huggingface_hub/_hot_reload/client.py @@ -0,0 +1,103 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +from collections import deque +from typing import Iterator, Literal, Optional, TypedDict, Union + +import httpx + +from ..utils._headers import build_hf_headers +from ..utils._http import hf_raise_for_status +from .sse_client import SSEClient +from .types import ApiGetReloadEventSourceData, ApiGetReloadRequest + + +HOT_RELOADING_PORT = 7887 + + +class MultiReplicaStreamEvent(TypedDict): + kind: Literal["event"] + event: ApiGetReloadEventSourceData + + +class MultiReplicaStreamReplicaHash(TypedDict): + kind: Literal["replicaHash"] + hash: str + + +class MultiReplicaStreamFullMatch(TypedDict): + kind: Literal["fullMatch"] + + +class ReloadClient: + def __init__( + self, + *, + host: str, + subdomain: str, + replica_hash: str, + token: Optional[str], + ): + base_host = host.replace(subdomain, f"{subdomain}--{HOT_RELOADING_PORT}") + self.replica_hash = replica_hash + self.client = httpx.Client( + base_url=f"{base_host}/--replicas/+{replica_hash}", + headers=build_hf_headers(token=token), + ) + + def get_reload(self, reload_id: str) -> Iterator[ApiGetReloadEventSourceData]: + req = ApiGetReloadRequest(reloadId=reload_id) + with self.client.stream("POST", "/get-reload", json=req) as res: + hf_raise_for_status(res) + for event in SSEClient(res.iter_bytes()).events(): + if event.event == "message": + yield json.loads(event.data) + + +def multi_replica_reload_events( + commit_sha: str, + host: str, + subdomain: str, + replica_hashes: list[str], + token: Optional[str], +) -> Iterator[Union[MultiReplicaStreamEvent, MultiReplicaStreamReplicaHash, MultiReplicaStreamFullMatch]]: + clients = [ + ReloadClient( + host=host, + subdomain=subdomain, + replica_hash=hash, + token=token, + ) + for hash in replica_hashes + ] + + first_client_events: dict[int, ApiGetReloadEventSourceData] = {} + for client_index, client in enumerate(clients): + if len(clients) > 1: + yield {"kind": "replicaHash", "hash": client.replica_hash} + full_match = True + replay: deque[ApiGetReloadEventSourceData] = deque() + for event_index, event in enumerate(client.get_reload(commit_sha)): + if client_index == 0: + first_client_events[event_index] = event + elif full_match := full_match and first_client_events.get(event_index) == event: + replay.append(event) + continue + while replay: + yield {"kind": "event", "event": replay.popleft()} + yield {"kind": "event", "event": event} + if client_index > 0 and full_match: + yield {"kind": "fullMatch"} diff --git a/huggingface_hub/_hot_reload/sse_client.py b/huggingface_hub/_hot_reload/sse_client.py new file mode 100644 index 0000000000000000000000000000000000000000..0798a04028c9e70252284ee8f87197cbd7d9b1d0 --- /dev/null +++ b/huggingface_hub/_hot_reload/sse_client.py @@ -0,0 +1,144 @@ +""" +Vendored file: Server Side Events (SSE) client for Python. + +Source: +- Author: Maxime Petazzoni +- Repository: https://github.com/mpetazzoni/sseclient +- File: https://github.com/mpetazzoni/sseclient/blob/main/sseclient/__init__.py + +License: +- Apache-2.0 (from upstream project) + +Provides a generator of SSE received through an existing HTTP response. +""" + +import logging + +__author__ = 'Maxime Petazzoni ' +__email__ = 'maxime.petazzoni@bulix.org' +__all__ = ['SSEClient'] + +_FIELD_SEPARATOR = ':' + + +class SSEClient(object): + """Implementation of a SSE client. + + See http://www.w3.org/TR/2009/WD-eventsource-20091029/ for the + specification. + """ + + def __init__(self, event_source, char_enc='utf-8'): + """Initialize the SSE client over an existing, ready to consume + event source. + + The event source is expected to be a binary stream and have a close() + method. That would usually be something that implements + io.BinaryIOBase, like an httplib or urllib3 HTTPResponse object. + """ + self._logger = logging.getLogger(self.__class__.__module__) + self._logger.debug('Initialized SSE client from event source %s', + event_source) + self._event_source = event_source + self._char_enc = char_enc + + def _read(self): + """Read the incoming event source stream and yield event chunks. + + Unfortunately it is possible for some servers to decide to break an + event into multiple HTTP chunks in the response. It is thus necessary + to correctly stitch together consecutive response chunks and find the + SSE delimiter (empty new line) to yield full, correct event chunks.""" + data = b'' + for chunk in self._event_source: + for line in chunk.splitlines(True): + data += line + if data.endswith((b'\r\r', b'\n\n', b'\r\n\r\n')): + yield data + data = b'' + if data: + yield data + + def events(self): + for chunk in self._read(): + event = Event() + # Split before decoding so splitlines() only uses \r and \n + for line in chunk.splitlines(): + # Decode the line. + line = line.decode(self._char_enc) + + # Lines starting with a separator are comments and are to be + # ignored. + if not line.strip() or line.startswith(_FIELD_SEPARATOR): + continue + + data = line.split(_FIELD_SEPARATOR, 1) + field = data[0] + + # Ignore unknown fields. + if field not in event.__dict__: + self._logger.debug('Saw invalid field %s while parsing ' + 'Server Side Event', field) + continue + + if len(data) > 1: + # From the spec: + # "If value starts with a single U+0020 SPACE character, + # remove it from value." + if data[1].startswith(' '): + value = data[1][1:] + else: + value = data[1] + else: + # If no value is present after the separator, + # assume an empty value. + value = '' + + # The data field may come over multiple lines and their values + # are concatenated with each other. + if field == 'data': + event.__dict__[field] += value + '\n' + else: + event.__dict__[field] = value + + # Events with no data are not dispatched. + if not event.data: + continue + + # If the data field ends with a newline, remove it. + if event.data.endswith('\n'): + event.data = event.data[0:-1] + + # Empty event names default to 'message' + event.event = event.event or 'message' + + # Dispatch the event + self._logger.debug('Dispatching %s...', event) + yield event + + def close(self): + """Manually close the event source stream.""" + self._event_source.close() + + +class Event(object): + """Representation of an event from the event stream.""" + + def __init__(self, id=None, event='message', data='', retry=None): + self.id = id + self.event = event + self.data = data + self.retry = retry + + def __str__(self): + s = '{0} event'.format(self.event) + if self.id: + s += ' #{0}'.format(self.id) + if self.data: + s += ', {0} byte{1}'.format(len(self.data), + 's' if len(self.data) else '') + else: + s += ', no data' + if self.retry: + s += ', retry in {0}ms'.format(self.retry) + return s diff --git a/huggingface_hub/_hot_reload/types.py b/huggingface_hub/_hot_reload/types.py new file mode 100644 index 0000000000000000000000000000000000000000..c2a017de48607f619aab69b4cdf9a7b296611581 --- /dev/null +++ b/huggingface_hub/_hot_reload/types.py @@ -0,0 +1,115 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Literal, TypedDict, Union + +from typing_extensions import NotRequired + + +class ReloadRegion(TypedDict): + startLine: int + startCol: int + endLine: int + endCol: int + + +class ReloadOperationObject(TypedDict): + kind: Literal["add", "update", "delete"] + region: ReloadRegion + objectType: str + objectName: str + + +class ReloadOperationRun(TypedDict): + kind: Literal["run"] + region: ReloadRegion + codeLines: str + stdout: NotRequired[str] + stderr: NotRequired[str] + + +class ReloadOperationException(TypedDict): + kind: Literal["exception"] + region: ReloadRegion + traceback: str + + +class ReloadOperationError(TypedDict): + kind: Literal["error"] + traceback: str + + +class ReloadOperationUI(TypedDict): + kind: Literal["ui"] + updated: bool + + +class ApiCreateReloadRequest(TypedDict): + filepath: str + contents: str + reloadId: NotRequired[str] + + +class ApiCreateReloadResponseSuccess(TypedDict): + status: Literal["created"] + reloadId: str + + +class ApiCreateReloadResponseError(TypedDict): + status: Literal["alreadyReloading", "fileNotFound"] + + +class ApiCreateReloadResponse(TypedDict): + res: Union[ApiCreateReloadResponseError, ApiCreateReloadResponseSuccess] + + +class ApiGetReloadRequest(TypedDict): + reloadId: str + + +class ApiGetReloadEventSourceData(TypedDict): + data: Union[ + ReloadOperationError, + ReloadOperationException, + ReloadOperationObject, + ReloadOperationRun, + ReloadOperationUI, + ] + + +class ApiGetStatusRequest(TypedDict): + revision: str + + +class ApiGetStatusResponse(TypedDict): + reloading: bool + uncommited: list[str] + + +class ApiFetchContentsRequest(TypedDict): + filepath: str + + +class ApiFetchContentsResponseError(TypedDict): + status: Literal["fileNotFound"] + + +class ApiFetchContentsResponseSuccess(TypedDict): + status: Literal["ok"] + contents: str + + +class ApiFetchContentsResponse(TypedDict): + res: Union[ApiFetchContentsResponseError, ApiFetchContentsResponseSuccess] diff --git a/huggingface_hub/cli/__init__.py b/huggingface_hub/cli/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8568c82be1c638c0ccd34d460fd8b0f73dcbec4e --- /dev/null +++ b/huggingface_hub/cli/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/huggingface_hub/cli/_cli_utils.py b/huggingface_hub/cli/_cli_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..710beab4c77e5cf3b5425f85e9ff602d49c6c65b --- /dev/null +++ b/huggingface_hub/cli/_cli_utils.py @@ -0,0 +1,607 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains CLI utilities (styling, helpers).""" + +import dataclasses +import datetime +import importlib.metadata +import json +import os +import re +import time +from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING, Annotated, Any, Callable, Literal, Optional, Sequence, Union, cast + +import click +import typer + +from huggingface_hub import __version__, constants +from huggingface_hub.utils import ANSI, get_session, hf_raise_for_status, installation_method, logging, tabulate + + +logger = logging.get_logger() + +# Arbitrary maximum length of a cell in a table output +_MAX_CELL_LENGTH = 35 + +if TYPE_CHECKING: + from huggingface_hub.hf_api import HfApi + + +def get_hf_api(token: Optional[str] = None) -> "HfApi": + # Import here to avoid circular import + from huggingface_hub.hf_api import HfApi + + return HfApi(token=token, library_name="huggingface-cli", library_version=__version__) + + +#### TYPER UTILS + +CLI_REFERENCE_URL = "https://huggingface.co/docs/huggingface_hub/en/guides/cli" + + +def generate_epilog(examples: list[str], docs_anchor: Optional[str] = None) -> str: + """Generate an epilog with examples and a Learn More section. + + Args: + examples: List of example commands (without the `$ ` prefix). + docs_anchor: Optional anchor for the docs URL (e.g., "#hf-download"). + + Returns: + Formatted epilog string. + """ + docs_url = f"{CLI_REFERENCE_URL}{docs_anchor}" if docs_anchor else CLI_REFERENCE_URL + examples_str = "\n".join(f" $ {ex}" for ex in examples) + return f"""\ +Examples +{examples_str} + +Learn more + Use `hf --help` for more information about a command. + Read the documentation at {docs_url} +""" + + +TOPIC_T = Union[Literal["main", "help"], str] +FallbackHandlerT = Callable[[list[str], set[str]], Optional[int]] + + +def _format_epilog_no_indent(epilog: Optional[str], ctx: click.Context, formatter: click.HelpFormatter) -> None: + """Write the epilog without indentation.""" + if epilog: + formatter.write_paragraph() + for line in epilog.split("\n"): + formatter.write_text(line) + + +_ALIAS_SPLIT = re.compile(r"\s*\|\s*") + + +class HFCliTyperGroup(typer.core.TyperGroup): + """ + Typer Group that: + - lists commands alphabetically within sections. + - separates commands by topic (main, help, etc.). + - formats epilog without extra indentation. + - supports aliases via pipe-separated names (e.g. ``name="list | ls"``). + - rewrites ``--json`` to ``--format json`` for commands that accept ``--format``. + """ + + def resolve_command(self, ctx: click.Context, args: list[str]) -> tuple: + # Rewrite hidden --json shorthand to --format json, but only for commands that accept --format. + # This avoids rewriting `--json` for commands that pass args through to external binaries + # (e.g. `hf extensions exec`) or that simply don't support `--format`. + if "--json" in args: + cmd_name = args[0] if args and not args[0].startswith("-") else None + cmd = self.get_command(ctx, cmd_name) if cmd_name else None + has_format_option = cmd is not None and any( + isinstance(param, click.Option) and "--format" in param.opts for param in cmd.params + ) + if has_format_option: + if any(arg == "--format" or arg.startswith("--format=") for arg in args): + raise click.UsageError("'--json' and '--format' are mutually exclusive.") + idx = args.index("--json") + args[idx : idx + 1] = ["--format", "json"] + return super().resolve_command(ctx, args) + + def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]: + # Try exact match first + cmd = super().get_command(ctx, cmd_name) + if cmd is not None: + return cmd + # Fall back to alias lookup: check if cmd_name matches any alias + # taken from https://github.com/fastapi/typer/issues/132#issuecomment-2417492805 + for registered_name, registered_cmd in self.commands.items(): + aliases = _ALIAS_SPLIT.split(registered_name) + if cmd_name in aliases: + return registered_cmd + return None + + def _alias_map(self) -> dict[str, list[str]]: + """Build a mapping from primary command name to its aliases (if any).""" + result: dict[str, list[str]] = {} + for registered_name in self.commands: + parts = _ALIAS_SPLIT.split(registered_name) + primary = parts[0] + result[primary] = parts[1:] + return result + + def format_commands(self, ctx: click.Context, formatter: click.HelpFormatter) -> None: + topics: dict[str, list] = {} + alias_map = self._alias_map() + + for name in self.list_commands(ctx): + cmd = self.get_command(ctx, name) + if cmd is None or cmd.hidden: + continue + help_text = cmd.get_short_help_str(limit=formatter.width) + aliases = alias_map.get(name, []) + if aliases: + help_text = f"{help_text} [alias: {', '.join(aliases)}]" + topic = getattr(cmd, "topic", "main") + topics.setdefault(topic, []).append((name, help_text)) + + with formatter.section("Main commands"): + formatter.write_dl(topics["main"]) + for topic in sorted(topics.keys()): + if topic == "main": + continue + with formatter.section(f"{topic.capitalize()} commands"): + formatter.write_dl(topics[topic]) + + def format_epilog(self, ctx: click.Context, formatter: click.HelpFormatter) -> None: + # Collect only the first example from each command (to keep group help concise) + # Full examples are shown in individual subcommand help (e.g. `hf buckets sync --help`) + all_examples: list[str] = [] + for name in self.list_commands(ctx): + cmd = self.get_command(ctx, name) + if cmd is None or cmd.hidden: + continue + cmd_examples = getattr(cmd, "examples", []) + if cmd_examples: + all_examples.append(cmd_examples[0]) + + if all_examples: + epilog = generate_epilog(all_examples) + _format_epilog_no_indent(epilog, ctx, formatter) + elif self.epilog: + _format_epilog_no_indent(self.epilog, ctx, formatter) + + def list_commands(self, ctx: click.Context) -> list[str]: # type: ignore[name-defined] + # For aliased commands ("list | ls"), use the primary name (first entry). + primary_names: list[str] = [] + for name in self.commands: + primary = _ALIAS_SPLIT.split(name)[0] + primary_names.append(primary) + return sorted(primary_names) + + +def fallback_typer_group_factory( + fallback_handler: FallbackHandlerT, + extra_commands_provider: Optional[Callable[[], list[tuple[str, str]]]] = None, +) -> type[HFCliTyperGroup]: + """Return a Typer group class that runs a fallback handler before command resolution.""" + + class FallbackTyperGroup(HFCliTyperGroup): + def resolve_command(self, ctx: click.Context, args: list[str]) -> tuple: + fallback_exit_code = fallback_handler(args, set(self.commands.keys())) + if fallback_exit_code is not None: + raise SystemExit(fallback_exit_code) + return super().resolve_command(ctx, args) + + def format_commands(self, ctx: click.Context, formatter: click.HelpFormatter) -> None: + super().format_commands(ctx, formatter) + if extra_commands_provider is not None: + entries = extra_commands_provider() + if entries: + with formatter.section("Extension commands"): + formatter.write_dl(entries) + + return FallbackTyperGroup + + +def HFCliCommand(topic: TOPIC_T, examples: Optional[list[str]] = None) -> type[typer.core.TyperCommand]: + def format_epilog(self: click.Command, ctx: click.Context, formatter: click.HelpFormatter) -> None: + _format_epilog_no_indent(self.epilog, ctx, formatter) + + return type( + f"TyperCommand{topic.capitalize()}", + (typer.core.TyperCommand,), + {"topic": topic, "examples": examples or [], "format_epilog": format_epilog}, + ) + + +class HFCliApp(typer.Typer): + """Custom Typer app for Hugging Face CLI.""" + + def command( # type: ignore[override] + self, + name: Optional[str] = None, + *, + topic: TOPIC_T = "main", + examples: Optional[list[str]] = None, + context_settings: Optional[dict[str, Any]] = None, + help: Optional[str] = None, + epilog: Optional[str] = None, + short_help: Optional[str] = None, + options_metavar: str = "[OPTIONS]", + add_help_option: bool = True, + no_args_is_help: bool = False, + hidden: bool = False, + deprecated: bool = False, + rich_help_panel: Optional[str] = None, + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + # Generate epilog from examples if not explicitly provided + if epilog is None and examples: + epilog = generate_epilog(examples) + + def _inner(func: Callable[..., Any]) -> Callable[..., Any]: + return super(HFCliApp, self).command( + name, + cls=HFCliCommand(topic, examples), + context_settings=context_settings, + help=help, + epilog=epilog, + short_help=short_help, + options_metavar=options_metavar, + add_help_option=add_help_option, + no_args_is_help=no_args_is_help, + hidden=hidden, + deprecated=deprecated, + rich_help_panel=rich_help_panel, + )(func) + + return _inner + + +def typer_factory( + help: str, epilog: Optional[str] = None, cls: Optional[type[typer.core.TyperGroup]] = None +) -> "HFCliApp": + """Create a Typer app with consistent settings. + + Args: + help: Help text for the app. + epilog: Optional epilog text (use `generate_epilog` to create one). + cls: Optional Click group class to use (defaults to `HFCliTyperGroup`). + + Returns: + A configured Typer app. + """ + if cls is None: + cls = HFCliTyperGroup + return HFCliApp( + help=help, + epilog=epilog, + add_completion=True, + no_args_is_help=True, + cls=cls, + # Disable rich completely for consistent experience + rich_markup_mode=None, + rich_help_panel=None, + pretty_exceptions_enable=False, + # Increase max content width for better readability + context_settings={ + "max_content_width": 120, + "help_option_names": ["-h", "--help"], + }, + ) + + +class RepoType(str, Enum): + model = "model" + dataset = "dataset" + space = "space" + + +RepoIdArg = Annotated[ + str, + typer.Argument( + help="The ID of the repo (e.g. `username/repo-name`).", + ), +] + + +RepoTypeOpt = Annotated[ + RepoType, + typer.Option( + "--type", + "--repo-type", + help="The type of repository (model, dataset, or space).", + ), +] + +TokenOpt = Annotated[ + Optional[str], + typer.Option( + help="A User Access Token generated from https://huggingface.co/settings/tokens.", + ), +] + +PrivateOpt = Annotated[ + Optional[bool], + typer.Option( + help="Whether to create a private repo if repo doesn't exist on the Hub. Ignored if the repo already exists.", + ), +] + +RevisionOpt = Annotated[ + Optional[str], + typer.Option( + help="Git revision id which can be a branch name, a tag, or a commit hash.", + ), +] + + +LimitOpt = Annotated[ + int, + typer.Option(help="Limit the number of results."), +] + +AuthorOpt = Annotated[ + Optional[str], + typer.Option(help="Filter by author or organization."), +] + +FilterOpt = Annotated[ + Optional[list[str]], + typer.Option(help="Filter by tags (e.g. 'text-classification'). Can be used multiple times."), +] + +SearchOpt = Annotated[ + Optional[str], + typer.Option(help="Search query."), +] + + +class OutputFormat(str, Enum): + """Output format for CLI list commands.""" + + table = "table" + json = "json" + + +FormatOpt = Annotated[ + OutputFormat, + typer.Option( + help="Output format (table or json).", + ), +] + +QuietOpt = Annotated[ + bool, + typer.Option( + "-q", + "--quiet", + help="Print only IDs (one per line).", + ), +] + + +def _to_header(name: str) -> str: + """Convert a camelCase or PascalCase string to SCREAMING_SNAKE_CASE to be used as table header.""" + s = re.sub(r"([a-z])([A-Z])", r"\1_\2", name) + return s.upper() + + +def _format_value(value: Any) -> str: + """Convert a value to string for terminal display.""" + if not value: + return "" + if isinstance(value, bool): + return "✔" if value else "" + if isinstance(value, datetime.datetime): + return value.strftime("%Y-%m-%d") + if isinstance(value, str) and re.match(r"^\d{4}-\d{2}-\d{2}T", value): + return value[:10] + if isinstance(value, list): + return ", ".join(_format_value(v) for v in value) + elif isinstance(value, dict): + if "name" in value: # Likely to be a user or org => print name + return str(value["name"]) + # TODO: extend if needed + return json.dumps(value) + return str(value) + + +def _format_cell(value: Any, max_len: int = _MAX_CELL_LENGTH) -> str: + """Format a value + truncate it for table display.""" + cell = _format_value(value) + if len(cell) > max_len: + cell = cell[: max_len - 3] + "..." + return cell + + +def print_as_table( + items: Sequence[dict[str, Any]], + headers: list[str], + row_fn: Callable[[dict[str, Any]], list[str]], + alignments: Optional[dict[str, str]] = None, +) -> None: + """Print items as a formatted table. + + Args: + items: Sequence of dictionaries representing the items to display. + headers: List of column headers. + row_fn: Function that takes an item dict and returns a list of string values for each column. + alignments: Optional mapping of header name to "left" or "right". Defaults to "left". + """ + if not items: + print("No results found.") + return + rows = cast(list[list[Union[str, int]]], [row_fn(item) for item in items]) + screaming_headers = [_to_header(h) for h in headers] + # Remap alignments keys to screaming case to match tabulate headers + screaming_alignments = {_to_header(k): v for k, v in (alignments or {}).items()} + print(tabulate(rows, headers=screaming_headers, alignments=screaming_alignments)) + + +def print_list_output( + items: Sequence[dict[str, Any]], + format: OutputFormat, + quiet: bool, + id_key: str = "id", + headers: Optional[list[str]] = None, + row_fn: Optional[Callable[[dict[str, Any]], list[str]]] = None, + alignments: Optional[dict[str, str]] = None, +) -> None: + """Print list command output in the specified format. + + Args: + items: Sequence of dictionaries representing the items to display. + format: Output format (table or json). + quiet: If True, print only IDs (one per line). + id_key: Key to use for extracting IDs in quiet mode. + headers: Optional list of column names for headers. If not provided, auto-detected from keys. + row_fn: Optional function to extract row values. If not provided, uses _format_cell on each column. + alignments: Optional mapping of header name to "left" or "right". Defaults to "left". + """ + if quiet: + for item in items: + print(item[id_key]) + return + + if format == OutputFormat.json: + print(json.dumps(list(items), indent=2, default=str)) + return + + if headers is None: + all_columns = list(items[0].keys()) if items else [id_key] + headers = [col for col in all_columns if any(_format_cell(item.get(col)) for item in items)] + + if row_fn is None: + + def row_fn(item: dict[str, Any]) -> list[str]: + return [_format_cell(item.get(col)) for col in headers] # type: ignore[union-attr] + + print_as_table(items, headers=headers, row_fn=row_fn, alignments=alignments) + + +def _serialize_value(v: object) -> object: + """Recursively serialize a value to be JSON-compatible.""" + if isinstance(v, datetime.datetime): + return v.isoformat() + elif isinstance(v, dict): + return {key: _serialize_value(val) for key, val in v.items() if val is not None} + elif isinstance(v, list): + return [_serialize_value(item) for item in v] + return v + + +def api_object_to_dict(info: Any) -> dict[str, Any]: + """Convert repo info dataclasses to json-serializable dicts.""" + return {k: _serialize_value(v) for k, v in dataclasses.asdict(info).items() if v is not None} + + +def make_expand_properties_parser(valid_properties: list[str]): + """Create a callback to parse and validate comma-separated expand properties.""" + + def _parse_expand_properties(value: Optional[str]) -> Optional[list[str]]: + if value is None: + return None + properties = [p.strip() for p in value.split(",")] + for prop in properties: + if prop not in valid_properties: + raise typer.BadParameter( + f"Invalid expand property: '{prop}'. Valid values are: {', '.join(valid_properties)}" + ) + return properties + + return _parse_expand_properties + + +### PyPI VERSION CHECKER + + +def check_cli_update(library: Literal["huggingface_hub", "transformers"]) -> None: + """ + Check whether a newer version of a library is available on PyPI. + + If a newer version is found, notify the user and suggest updating. + If current version is a pre-release (e.g. `1.0.0.rc1`), or a dev version (e.g. `1.0.0.dev1`), no check is performed. + + This function is called at the entry point of the CLI. It only performs the check once every 24 hours, and any error + during the check is caught and logged, to avoid breaking the CLI. + + Args: + library: The library to check for updates. Currently supports "huggingface_hub" and "transformers". + """ + try: + _check_cli_update(library) + except Exception: + # We don't want the CLI to fail on version checks, no matter the reason. + logger.debug("Error while checking for CLI update.", exc_info=True) + + +def _check_cli_update(library: Literal["huggingface_hub", "transformers"]) -> None: + current_version = importlib.metadata.version(library) + + # Skip if current version is a pre-release or dev version + if any(tag in current_version for tag in ["rc", "dev"]): + return + + # Skip if already checked in the last 24 hours + if os.path.exists(constants.CHECK_FOR_UPDATE_DONE_PATH): + mtime = os.path.getmtime(constants.CHECK_FOR_UPDATE_DONE_PATH) + if (time.time() - mtime) < 24 * 3600: + return + + # Touch the file to mark that we did the check now + Path(constants.CHECK_FOR_UPDATE_DONE_PATH).parent.mkdir(parents=True, exist_ok=True) + Path(constants.CHECK_FOR_UPDATE_DONE_PATH).touch() + + # Check latest version from PyPI + response = get_session().get(f"https://pypi.org/pypi/{library}/json", timeout=2) + hf_raise_for_status(response) + data = response.json() + latest_version = data["info"]["version"] + + # If latest version is different from current, notify user + if current_version != latest_version: + if library == "huggingface_hub": + update_command = _get_huggingface_hub_update_command() + else: + update_command = _get_transformers_update_command() + + click.echo( + ANSI.yellow( + f"A new version of {library} ({latest_version}) is available! " + f"You are using version {current_version}.\n" + f"To update, run: {ANSI.bold(update_command)}\n", + ) + ) + + +def _get_huggingface_hub_update_command() -> str: + """Return the command to update huggingface_hub.""" + method = installation_method() + if method == "brew": + return "brew upgrade huggingface-cli" + elif method == "hf_installer" and os.name == "nt": + return 'powershell -NoProfile -Command "iwr -useb https://hf.co/cli/install.ps1 | iex"' + elif method == "hf_installer": + return "curl -LsSf https://hf.co/cli/install.sh | bash -" + else: # unknown => likely pip + return "pip install -U huggingface_hub" + + +def _get_transformers_update_command() -> str: + """Return the command to update transformers.""" + method = installation_method() + if method == "hf_installer" and os.name == "nt": + return 'powershell -NoProfile -Command "iwr -useb https://hf.co/cli/install.ps1 | iex" -WithTransformers' + elif method == "hf_installer": + return "curl -LsSf https://hf.co/cli/install.sh | bash -s -- --with-transformers" + else: # brew/unknown => likely pip + return "pip install -U transformers" diff --git a/huggingface_hub/cli/_errors.py b/huggingface_hub/cli/_errors.py new file mode 100644 index 0000000000000000000000000000000000000000..440d4f465200c890082046e0df0950b50975e23c --- /dev/null +++ b/huggingface_hub/cli/_errors.py @@ -0,0 +1,113 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""CLI error handling utilities.""" + +import traceback +from typing import Callable, Optional + +from huggingface_hub.errors import ( + BucketNotFoundError, + CLIError, + CLIExtensionInstallError, + GatedRepoError, + HfHubHTTPError, + LocalTokenNotFoundError, + RemoteEntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError, +) + + +def _format_repo_not_found(error: RepositoryNotFoundError) -> str: + label = error.repo_type.capitalize() if error.repo_type else "Repository" + if error.repo_id: + msg = f"{label} '{error.repo_id}' not found." + else: + msg = f"{label} not found." + msg += " If the repo is private, make sure you are authenticated." + return msg + + +def _format_gated_repo(error: GatedRepoError) -> str: + label = error.repo_type if error.repo_type else "repository" + if error.repo_id: + return f"Access denied. {label.capitalize()} '{error.repo_id}' requires approval." + return f"Access denied. This {label} requires approval." + + +def _format_bucket_not_found(error: BucketNotFoundError) -> str: + if error.bucket_id: + return f"Bucket '{error.bucket_id}' not found. If the bucket is private, make sure you are authenticated." + return "Bucket not found. Check the bucket id (namespace/name). If the bucket is private, make sure you are authenticated." + + +def _format_entry_not_found(error: RemoteEntryNotFoundError) -> str: + label = error.repo_type if error.repo_type else "repository" + url = str(error.response.url) if error.response else None + if error.repo_id: + msg = f"File not found in {label} '{error.repo_id}'." + else: + msg = f"File not found in {label}." + if url: + msg += f"\nURL: {url}" + return msg + + +def _format_revision_not_found(error: RevisionNotFoundError) -> str: + label = error.repo_type if error.repo_type else "repository" + if error.repo_id: + return f"Revision not found in {label} '{error.repo_id}'." + return f"Revision not found in {label}. Check the revision parameter." + + +def _format_cli_error(error: CLIError) -> str: + """No traceback, just the error message.""" + return str(error) + + +def _format_cli_extension_install_error(error: CLIExtensionInstallError) -> str: + """Format a CLI extension installation error. + + The error is likely to be a tricky subprocess error to investigate. In this specific case we want to format the + traceback of the root cause while keeping the "nicely formatted" error message of the CLIExtensionInstallError + as a 1-line message. + """ + cause_tb = ( + "".join(traceback.format_exception(type(error.__cause__), error.__cause__, error.__cause__.__traceback__)) + if error.__cause__ is not None + else "" + ) + return f"{cause_tb}\n{error}" + + +CLI_ERROR_MAPPINGS: dict[type[Exception], Callable[..., str]] = { + # GatedRepoError must come before RepositoryNotFoundError (it's a subclass). + GatedRepoError: _format_gated_repo, + BucketNotFoundError: _format_bucket_not_found, + RepositoryNotFoundError: _format_repo_not_found, + RevisionNotFoundError: _format_revision_not_found, + LocalTokenNotFoundError: lambda _: "Not logged in. Run 'hf auth login' first.", + RemoteEntryNotFoundError: _format_entry_not_found, + HfHubHTTPError: lambda error: str(error), + ValueError: lambda error: f"Invalid value. {error}", + CLIExtensionInstallError: _format_cli_extension_install_error, + CLIError: _format_cli_error, +} + + +def format_known_exception(error: Exception) -> Optional[str]: + for exc_type, formatter in CLI_ERROR_MAPPINGS.items(): + if isinstance(error, exc_type): + return formatter(error) + return None diff --git a/huggingface_hub/cli/auth.py b/huggingface_hub/cli/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..3c69988835e6ffe5baef7e008e8010b773854eaf --- /dev/null +++ b/huggingface_hub/cli/auth.py @@ -0,0 +1,164 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains commands to authenticate to the Hugging Face Hub and interact with your repositories. + +Usage: + # login and save token locally. + hf auth login --token=hf_*** --add-to-git-credential + + # switch between tokens + hf auth switch + + # list all tokens + hf auth list + + # logout from all tokens + hf auth logout + + # check which account you are logged in as + hf auth whoami +""" + +from typing import Annotated, Optional + +import typer + +from huggingface_hub.constants import ENDPOINT +from huggingface_hub.hf_api import whoami + +from .._login import auth_list, auth_switch, login, logout +from ..utils import ANSI, get_stored_tokens, get_token, logging +from ._cli_utils import TokenOpt, typer_factory + + +logger = logging.get_logger(__name__) + + +auth_cli = typer_factory(help="Manage authentication (login, logout, etc.).") + + +@auth_cli.command( + "login", + examples=[ + "hf auth login", + "hf auth login --token $HF_TOKEN", + "hf auth login --token $HF_TOKEN --add-to-git-credential", + "hf auth login --force", + ], +) +def auth_login( + token: TokenOpt = None, + add_to_git_credential: Annotated[ + bool, + typer.Option( + help="Save to git credential helper. Useful only if you plan to run git commands directly.", + ), + ] = False, + force: Annotated[ + bool, + typer.Option( + help="Force re-login even if already logged in.", + ), + ] = False, +) -> None: + """Login using a token from huggingface.co/settings/tokens.""" + login(token=token, add_to_git_credential=add_to_git_credential, skip_if_logged_in=not force) + + +@auth_cli.command( + "logout", + examples=["hf auth logout", "hf auth logout --token-name my-token"], +) +def auth_logout( + token_name: Annotated[ + Optional[str], + typer.Option(help="Name of token to logout"), + ] = None, +) -> None: + """Logout from a specific token.""" + logout(token_name=token_name) + + +def _select_token_name() -> Optional[str]: + token_names = list(get_stored_tokens().keys()) + + if not token_names: + logger.error("No stored tokens found. Please login first.") + return None + + print("Available stored tokens:") + for i, token_name in enumerate(token_names, 1): + print(f"{i}. {token_name}") + while True: + try: + choice = input("Enter the number of the token to switch to (or 'q' to quit): ") + if choice.lower() == "q": + return None + index = int(choice) - 1 + if 0 <= index < len(token_names): + return token_names[index] + else: + print("Invalid selection. Please try again.") + except ValueError: + print("Invalid input. Please enter a number or 'q' to quit.") + + +@auth_cli.command( + "switch", + examples=["hf auth switch", "hf auth switch --token-name my-token"], +) +def auth_switch_cmd( + token_name: Annotated[ + Optional[str], + typer.Option( + help="Name of the token to switch to", + ), + ] = None, + add_to_git_credential: Annotated[ + bool, + typer.Option( + help="Save to git credential helper. Useful only if you plan to run git commands directly.", + ), + ] = False, +) -> None: + """Switch between access tokens.""" + if token_name is None: + token_name = _select_token_name() + if token_name is None: + print("No token name provided. Aborting.") + raise typer.Exit() + auth_switch(token_name, add_to_git_credential=add_to_git_credential) + + +@auth_cli.command("list | ls", examples=["hf auth list"]) +def auth_list_cmd() -> None: + """List all stored access tokens.""" + auth_list() + + +@auth_cli.command("whoami", examples=["hf auth whoami"]) +def auth_whoami() -> None: + """Find out which huggingface.co account you are logged in as.""" + token = get_token() + if token is None: + print("Not logged in") + raise typer.Exit() + info = whoami(token) + print(ANSI.bold("user: "), info["name"]) + orgs = [org["name"] for org in info["orgs"]] + if orgs: + print(ANSI.bold("orgs: "), ",".join(orgs)) + + if ENDPOINT != "https://huggingface.co": + print(f"Authenticated through private endpoint: {ENDPOINT}") diff --git a/huggingface_hub/cli/buckets.py b/huggingface_hub/cli/buckets.py new file mode 100644 index 0000000000000000000000000000000000000000..e058d1a28f42318f80d4bd93b478784446cc8d21 --- /dev/null +++ b/huggingface_hub/cli/buckets.py @@ -0,0 +1,1056 @@ +# coding=utf-8 +# Copyright 2025-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains commands to interact with buckets via the CLI.""" + +import json +import os +import sys +import tempfile +from datetime import datetime +from typing import Annotated, Optional, Union + +import typer + +from huggingface_hub import logging +from huggingface_hub._buckets import ( + BUCKET_PREFIX, + BucketFile, + BucketFolder, + FilterMatcher, + _is_bucket_path, + _parse_bucket_path, + _split_bucket_id_and_prefix, +) +from huggingface_hub.utils import StatusLine, are_progress_bars_disabled, disable_progress_bars, enable_progress_bars + +from ._cli_utils import ( + FormatOpt, + OutputFormat, + QuietOpt, + TokenOpt, + api_object_to_dict, + get_hf_api, + print_list_output, + typer_factory, +) + + +logger = logging.get_logger(__name__) + + +buckets_cli = typer_factory(help="Commands to interact with buckets.") + + +def _parse_bucket_argument(argument: str) -> tuple[str, str]: + """Parse a bucket argument accepting both 'namespace/name(/prefix)' and 'hf://buckets/namespace/name(/prefix)'. + + Returns: + tuple: (bucket_id, prefix) where bucket_id is "namespace/bucket_name" and prefix may be empty string. + """ + if argument.startswith(BUCKET_PREFIX): + return _parse_bucket_path(argument) + try: + return _split_bucket_id_and_prefix(argument) + except ValueError: + raise ValueError( + f"Invalid bucket argument: {argument}. Must be in format namespace/bucket_name" + f" or {BUCKET_PREFIX}namespace/bucket_name" + ) + + +def _format_size(size: Union[int, float], human_readable: bool = False) -> str: + """Format a size in bytes.""" + if not human_readable: + return str(size) + + for unit in ["B", "KB", "MB", "GB", "TB"]: + if size < 1000: + if unit == "B": + return f"{size} {unit}" + return f"{size:.1f} {unit}" + size /= 1000 + return f"{size:.1f} PB" + + +def _format_mtime(mtime: Optional[datetime], human_readable: bool = False) -> str: + """Format mtime datetime to a readable date string.""" + if mtime is None: + return "" + if human_readable: + return mtime.strftime("%b %d %H:%M") + return mtime.strftime("%Y-%m-%d %H:%M:%S") + + +def _build_tree( + items: list[Union[BucketFile, BucketFolder]], + human_readable: bool = False, + quiet: bool = False, +) -> list[str]: + """Build a tree representation of files and directories. + + Produces ASCII tree with size and date columns before the tree connector. + When quiet=True, only the tree structure is shown (no size/date). + + Args: + items: List of BucketFile/BucketFolder items + human_readable: Whether to show human-readable sizes and short dates + quiet: If True, show only the tree structure without sizes/dates + + Returns: + List of formatted tree lines + """ + # Build a nested structure + tree: dict = {} + + for item in items: + parts = item.path.split("/") + current = tree + for part in parts[:-1]: + if part not in current: + current[part] = {"__children__": {}} + current = current[part]["__children__"] + + final_part = parts[-1] + if isinstance(item, BucketFolder): + if final_part not in current: + current[final_part] = {"__children__": {}} + else: + current[final_part] = {"__item__": item} + + # Compute prefix width for alignment (size + date columns) + prefix_width = 0 + max_size_width = 0 + max_date_width = 0 + if not quiet: + for item in items: + if isinstance(item, BucketFile): + size_str = _format_size(item.size, human_readable) + max_size_width = max(max_size_width, len(size_str)) + date_str = _format_mtime(item.mtime, human_readable) + max_date_width = max(max_date_width, len(date_str)) + if max_size_width > 0: + prefix_width = max_size_width + 2 + max_date_width + + # Render tree + lines: list[str] = [] + _render_tree( + tree, + lines, + "", + prefix_width=prefix_width, + max_size_width=max_size_width, + human_readable=human_readable, + ) + return lines + + +def _render_tree( + node: dict, + lines: list[str], + indent: str, + prefix_width: int = 0, + max_size_width: int = 0, + human_readable: bool = False, +) -> None: + """Recursively render a tree structure with size+date prefix.""" + items = sorted(node.items()) + for i, (name, value) in enumerate(items): + is_last = i == len(items) - 1 + connector = "└── " if is_last else "├── " + + is_dir = "__children__" in value + children = value.get("__children__", {}) + + if prefix_width > 0: + if is_dir: + prefix = " " * prefix_width + else: + item = value.get("__item__") + if item is not None: + size_str = _format_size(item.size, human_readable) + date_str = _format_mtime(item.mtime, human_readable) + prefix = f"{size_str:>{max_size_width}} {date_str}" + else: + prefix = " " * prefix_width + lines.append(f"{prefix} {indent}{connector}{name}{'/' if is_dir else ''}") + else: + lines.append(f"{indent}{connector}{name}{'/' if is_dir else ''}") + + if children: + child_indent = indent + (" " if is_last else "│ ") + _render_tree( + children, + lines, + child_indent, + prefix_width=prefix_width, + max_size_width=max_size_width, + human_readable=human_readable, + ) + + +@buckets_cli.command( + name="create", + examples=[ + "hf buckets create my-bucket", + "hf buckets create user/my-bucket", + "hf buckets create hf://buckets/user/my-bucket", + "hf buckets create user/my-bucket --private", + "hf buckets create user/my-bucket --exist-ok", + ], +) +def create( + bucket_id: Annotated[ + str, + typer.Argument( + help="Bucket ID: bucket_name, namespace/bucket_name, or hf://buckets/namespace/bucket_name", + ), + ], + private: Annotated[ + bool, + typer.Option( + "--private", + help="Create a private bucket.", + ), + ] = False, + exist_ok: Annotated[ + bool, + typer.Option( + "--exist-ok", + help="Do not raise an error if the bucket already exists.", + ), + ] = False, + quiet: QuietOpt = False, + token: TokenOpt = None, +) -> None: + """Create a new bucket.""" + api = get_hf_api(token=token) + + if bucket_id.startswith(BUCKET_PREFIX): + try: + parsed_id, prefix = _parse_bucket_argument(bucket_id) + except ValueError as e: + raise typer.BadParameter(str(e)) + if prefix: + raise typer.BadParameter( + f"Cannot specify a prefix for bucket creation: {bucket_id}." + f" Use namespace/bucket_name or {BUCKET_PREFIX}namespace/bucket_name." + ) + bucket_id = parsed_id + + bucket_url = api.create_bucket( + bucket_id, + private=private if private else None, + exist_ok=exist_ok, + ) + if quiet: + print(bucket_url.handle) + else: + print(f"Bucket created: {bucket_url.url} (handle: {bucket_url.handle})") + + +def _is_bucket_id(argument: str) -> bool: + """Check if argument is a bucket ID (namespace/name) vs just a namespace.""" + if argument.startswith(BUCKET_PREFIX): + path = argument[len(BUCKET_PREFIX) :] + else: + path = argument + return "/" in path + + +@buckets_cli.command( + name="list | ls", + examples=[ + "hf buckets list", + "hf buckets list huggingface", + "hf buckets list user/my-bucket", + "hf buckets list user/my-bucket -R", + "hf buckets list user/my-bucket -h", + "hf buckets list user/my-bucket --tree", + "hf buckets list user/my-bucket --tree -h", + "hf buckets list hf://buckets/user/my-bucket", + "hf buckets list user/my-bucket/sub -R", + ], +) +def list_cmd( + argument: Annotated[ + Optional[str], + typer.Argument( + help=( + "Namespace (user or org) to list buckets, or bucket ID" + " (namespace/bucket_name(/prefix) or hf://buckets/...) to list files." + ), + ), + ] = None, + human_readable: Annotated[ + bool, + typer.Option( + "--human-readable", + "-h", + help="Show sizes in human readable format.", + ), + ] = False, + as_tree: Annotated[ + bool, + typer.Option( + "--tree", + help="List files in tree format (only for listing files).", + ), + ] = False, + recursive: Annotated[ + bool, + typer.Option( + "--recursive", + "-R", + help="List files recursively (only for listing files).", + ), + ] = False, + format: FormatOpt = OutputFormat.table, + quiet: QuietOpt = False, + token: TokenOpt = None, +) -> None: + """List buckets or files in a bucket. + + When called with no argument or a namespace, lists buckets. + When called with a bucket ID (namespace/bucket_name), lists files in the bucket. + """ + # Determine mode: listing buckets or listing files + is_file_mode = argument is not None and _is_bucket_id(argument) + + if is_file_mode: + _list_files( + argument=argument, # type: ignore[arg-type] + human_readable=human_readable, + as_tree=as_tree, + recursive=recursive, + format=format, + quiet=quiet, + token=token, + ) + else: + _list_buckets( + namespace=argument, + human_readable=human_readable, + as_tree=as_tree, + recursive=recursive, + format=format, + quiet=quiet, + token=token, + ) + + +def _list_buckets( + namespace: Optional[str], + human_readable: bool, + as_tree: bool, + recursive: bool, + format: OutputFormat, + quiet: bool, + token: Optional[str], +) -> None: + """List buckets in a namespace.""" + # Validate incompatible flags + if as_tree: + raise typer.BadParameter("Cannot use --tree when listing buckets.") + if recursive: + raise typer.BadParameter("Cannot use --recursive when listing buckets.") + + # Handle hf://buckets/namespace format + if namespace is not None and namespace.startswith(BUCKET_PREFIX): + namespace = namespace[len(BUCKET_PREFIX) :] + # Strip trailing slash if any + namespace = namespace.rstrip("/") + + api = get_hf_api(token=token) + results = [api_object_to_dict(bucket) for bucket in api.list_buckets(namespace=namespace)] + + if not results: + if not quiet and format != OutputFormat.json: + resolved_namespace = namespace if namespace is not None else api.whoami()["name"] + print(f"No buckets found under namespace '{resolved_namespace}'.") + return + + headers = ["id", "private", "size", "total_files", "created_at"] + + def row_fn(item: dict) -> list[str]: + from ._cli_utils import _format_cell + + return [ + _format_cell(item.get("id")), + _format_cell(item.get("private")), + _format_size(item.get("size", 0), human_readable=human_readable), + _format_cell(item.get("total_files")), + _format_cell(item.get("created_at")), + ] + + alignments = {"size": "right", "total_files": "right"} + print_list_output(results, format=format, quiet=quiet, headers=headers, row_fn=row_fn, alignments=alignments) + + +def _list_files( + argument: str, + human_readable: bool, + as_tree: bool, + recursive: bool, + format: OutputFormat, + quiet: bool, + token: Optional[str], +) -> None: + """List files in a bucket.""" + # Validate incompatible flags + if as_tree and format == OutputFormat.json: + raise typer.BadParameter("Cannot use --tree with --format json.") + + api = get_hf_api(token=token) + + try: + bucket_id, prefix = _parse_bucket_argument(argument) + except ValueError as e: + raise typer.BadParameter(str(e)) + + # Fetch items from the bucket + items = list( + api.list_bucket_tree( + bucket_id, + prefix=prefix or None, + recursive=recursive, + ) + ) + + if not items: + print("(empty)") + return + + has_directories = any(isinstance(item, BucketFolder) for item in items) + + if format == OutputFormat.json: + results = [api_object_to_dict(item) for item in items] + print(json.dumps(results, indent=2)) + elif as_tree: + # Tree format with size+date prefix, or quiet for structure only + tree_lines = _build_tree(items, human_readable=human_readable, quiet=quiet) + for line in tree_lines: + print(line) + elif quiet: + for item in items: + if isinstance(item, BucketFolder): + print(f"{item.path}/") + else: + print(item.path) + else: + # Flat table format + for item in items: + if isinstance(item, BucketFolder): + mtime_str = _format_mtime(item.uploaded_at, human_readable) + print(f"{'':>12} {mtime_str:>19} {item.path}/") + else: + size_str = _format_size(item.size, human_readable) + mtime_str = _format_mtime(item.mtime, human_readable) + print(f"{size_str:>12} {mtime_str:>19} {item.path}") + + if not recursive and has_directories: + StatusLine().done("Use -R to list files recursively.") + + +@buckets_cli.command( + name="info", + examples=[ + "hf buckets info user/my-bucket", + "hf buckets info hf://buckets/user/my-bucket", + ], +) +def info( + bucket_id: Annotated[ + str, + typer.Argument( + help="Bucket ID: namespace/bucket_name or hf://buckets/namespace/bucket_name", + ), + ], + quiet: QuietOpt = False, + token: TokenOpt = None, +) -> None: + """Get info about a bucket.""" + api = get_hf_api(token=token) + + try: + parsed_id, _ = _parse_bucket_argument(bucket_id) + except ValueError as e: + raise typer.BadParameter(str(e)) + + bucket = api.bucket_info(parsed_id) + if quiet: + print(bucket.id) + else: + print(json.dumps(api_object_to_dict(bucket), indent=2)) + + +@buckets_cli.command( + name="delete", + examples=[ + "hf buckets delete user/my-bucket", + "hf buckets delete hf://buckets/user/my-bucket", + "hf buckets delete user/my-bucket --yes", + "hf buckets delete user/my-bucket --missing-ok", + ], +) +def delete( + bucket_id: Annotated[ + str, + typer.Argument( + help="Bucket ID: namespace/bucket_name or hf://buckets/namespace/bucket_name", + ), + ], + yes: Annotated[ + bool, + typer.Option( + "--yes", + "-y", + help="Skip confirmation prompt.", + ), + ] = False, + missing_ok: Annotated[ + bool, + typer.Option( + "--missing-ok", + help="Do not raise an error if the bucket does not exist.", + ), + ] = False, + quiet: QuietOpt = False, + token: TokenOpt = None, +) -> None: + """Delete a bucket. + + This deletes the entire bucket and all its contents. Use `hf buckets rm` to remove individual files. + """ + if bucket_id.startswith(BUCKET_PREFIX): + try: + parsed_id, prefix = _parse_bucket_argument(bucket_id) + except ValueError as e: + raise typer.BadParameter(str(e)) + if prefix: + raise typer.BadParameter( + f"Cannot specify a prefix for bucket deletion: {bucket_id}." + f" Use namespace/bucket_name or {BUCKET_PREFIX}namespace/bucket_name." + ) + bucket_id = parsed_id + elif "/" not in bucket_id: + raise typer.BadParameter( + f"Invalid bucket ID: {bucket_id}." + f" Must be in format namespace/bucket_name or {BUCKET_PREFIX}namespace/bucket_name." + ) + + if not yes: + confirm = typer.confirm(f"Are you sure you want to delete bucket '{bucket_id}'?") + if not confirm: + print("Aborted.") + raise typer.Abort() + + api = get_hf_api(token=token) + api.delete_bucket(bucket_id, missing_ok=missing_ok) + if quiet: + print(bucket_id) + else: + print(f"Bucket deleted: {bucket_id}") + + +@buckets_cli.command( + name="remove | rm", + examples=[ + "hf buckets remove user/my-bucket/file.txt", + "hf buckets rm hf://buckets/user/my-bucket/file.txt", + "hf buckets rm user/my-bucket/logs/ --recursive", + 'hf buckets rm user/my-bucket --recursive --include "*.tmp"', + "hf buckets rm user/my-bucket/data/ --recursive --dry-run", + ], +) +def remove( + argument: Annotated[ + str, + typer.Argument( + help=( + "Bucket path: namespace/bucket_name/path or hf://buckets/namespace/bucket_name/path." + " With --recursive, namespace/bucket_name is also accepted to target all files." + ), + ), + ], + recursive: Annotated[ + bool, + typer.Option( + "--recursive", + "-R", + help="Remove files recursively under the given prefix.", + ), + ] = False, + yes: Annotated[ + bool, + typer.Option( + "--yes", + "-y", + help="Skip confirmation prompt.", + ), + ] = False, + dry_run: Annotated[ + bool, + typer.Option( + "--dry-run", + help="Preview what would be deleted without actually deleting.", + ), + ] = False, + include: Annotated[ + Optional[list[str]], + typer.Option( + help="Include only files matching pattern (can specify multiple). Requires --recursive.", + ), + ] = None, + exclude: Annotated[ + Optional[list[str]], + typer.Option( + help="Exclude files matching pattern (can specify multiple). Requires --recursive.", + ), + ] = None, + quiet: QuietOpt = False, + token: TokenOpt = None, +) -> None: + """Remove files from a bucket. + + To delete an entire bucket, use `hf buckets delete` instead. + """ + try: + bucket_id, prefix = _parse_bucket_argument(argument) + except ValueError as e: + raise typer.BadParameter(str(e)) + + if prefix == "" and not recursive: + raise typer.BadParameter( + f"No file path specified. To remove files, provide a path" + f" (e.g. '{bucket_id}/FILE') or use --recursive to remove all files." + f" To delete the entire bucket, use `hf buckets delete {bucket_id}`." + ) + + if (include or exclude) and not recursive: + raise typer.BadParameter("--include and --exclude require --recursive.") + + api = get_hf_api(token=token) + + if recursive: + status = StatusLine(enabled=not quiet) + status.update("Listing files from remote") + + all_files: list[BucketFile] = [] + for item in api.list_bucket_tree( + bucket_id, + prefix=prefix.rstrip("/") or None, + recursive=True, + ): + if isinstance(item, BucketFile): + all_files.append(item) + status.update(f"Listing files from remote ({len(all_files)} files)") + status.done(f"Listing files from remote ({len(all_files)} files)") + + if include or exclude: + matcher = FilterMatcher(include_patterns=include, exclude_patterns=exclude) + matched_files = [f for f in all_files if matcher.matches(f.path)] + else: + matched_files = all_files + + file_paths = [f.path for f in matched_files] + total_size = sum(f.size for f in matched_files) + size_str = _format_size(total_size, human_readable=True) + + if not file_paths: + if not quiet: + print("No files to remove.") + return + + count_label = f"{len(file_paths)} file(s) totaling {size_str}" + + if not yes and not dry_run: + if not quiet: + for path in file_paths: + print(f" {path}") + confirm = typer.confirm(f"Remove {count_label} from '{bucket_id}'?") + if not confirm: + print("Aborted.") + raise typer.Abort() + + if dry_run: + for path in file_paths: + print(f"delete: {BUCKET_PREFIX}{bucket_id}/{path}") + print(f"(dry run) {count_label} would be removed.") + return + + api.batch_bucket_files(bucket_id, delete=file_paths) + if quiet: + for path in file_paths: + print(path) + else: + for path in file_paths: + print(f"delete: {BUCKET_PREFIX}{bucket_id}/{path}") + print(f"Removed {count_label} from '{bucket_id}'.") + + else: + file_path = prefix.rstrip("/") + if not file_path: + raise typer.BadParameter("File path cannot be empty.") + + if dry_run: + print(f"delete: {BUCKET_PREFIX}{bucket_id}/{file_path}") + print("(dry run) 1 file would be removed.") + return + + if not yes: + confirm = typer.confirm(f"Remove '{file_path}' from '{bucket_id}'?") + if not confirm: + print("Aborted.") + raise typer.Abort() + + api.batch_bucket_files(bucket_id, delete=[file_path]) + if quiet: + print(file_path) + else: + print(f"delete: {BUCKET_PREFIX}{bucket_id}/{file_path}") + + +@buckets_cli.command( + name="move", + examples=[ + "hf buckets move user/old-bucket user/new-bucket", + "hf buckets move user/my-bucket my-org/my-bucket", + "hf buckets move hf://buckets/user/old-bucket hf://buckets/user/new-bucket", + ], +) +def move( + from_id: Annotated[ + str, + typer.Argument( + help="Source bucket ID: namespace/bucket_name or hf://buckets/namespace/bucket_name", + ), + ], + to_id: Annotated[ + str, + typer.Argument( + help="Destination bucket ID: namespace/bucket_name or hf://buckets/namespace/bucket_name", + ), + ], + token: TokenOpt = None, +) -> None: + """Move (rename) a bucket to a new name or namespace.""" + # Parse from_id + parsed_from_id, from_prefix = _parse_bucket_argument(from_id) + if from_prefix: + raise typer.BadParameter( + f"Cannot specify a prefix for bucket move: {from_id}." + f" Use namespace/bucket_name or {BUCKET_PREFIX}namespace/bucket_name." + ) + + # Parse to_id + parsed_to_id, to_prefix = _parse_bucket_argument(to_id) + if to_prefix: + raise typer.BadParameter( + f"Cannot specify a prefix for bucket move: {to_id}." + f" Use namespace/bucket_name or {BUCKET_PREFIX}namespace/bucket_name." + ) + + api = get_hf_api(token=token) + api.move_bucket(from_id=parsed_from_id, to_id=parsed_to_id) + print(f"Bucket moved: {parsed_from_id} -> {parsed_to_id}") + + +# ============================================================================= +# Sync command +# ============================================================================= + + +@buckets_cli.command( + name="sync", + examples=[ + "hf buckets sync ./data hf://buckets/user/my-bucket", + "hf buckets sync hf://buckets/user/my-bucket ./data", + "hf buckets sync ./data hf://buckets/user/my-bucket --delete", + 'hf buckets sync hf://buckets/user/my-bucket ./data --include "*.safetensors" --exclude "*.tmp"', + "hf buckets sync ./data hf://buckets/user/my-bucket --plan sync-plan.jsonl", + "hf buckets sync --apply sync-plan.jsonl", + "hf buckets sync ./data hf://buckets/user/my-bucket --dry-run", + "hf buckets sync ./data hf://buckets/user/my-bucket --dry-run | jq .", + ], +) +def sync( + source: Annotated[ + Optional[str], + typer.Argument( + help="Source path: local directory or hf://buckets/namespace/bucket_name(/prefix)", + ), + ] = None, + dest: Annotated[ + Optional[str], + typer.Argument( + help="Destination path: local directory or hf://buckets/namespace/bucket_name(/prefix)", + ), + ] = None, + delete: Annotated[ + bool, + typer.Option( + help="Delete destination files not present in source.", + ), + ] = False, + ignore_times: Annotated[ + bool, + typer.Option( + "--ignore-times", + help="Skip files only based on size, ignoring modification times.", + ), + ] = False, + ignore_sizes: Annotated[ + bool, + typer.Option( + "--ignore-sizes", + help="Skip files only based on modification times, ignoring sizes.", + ), + ] = False, + plan: Annotated[ + Optional[str], + typer.Option( + help="Save sync plan to JSONL file for review instead of executing.", + ), + ] = None, + apply: Annotated[ + Optional[str], + typer.Option( + help="Apply a previously saved plan file.", + ), + ] = None, + dry_run: Annotated[ + bool, + typer.Option( + "--dry-run", + help="Print sync plan to stdout as JSONL without executing.", + ), + ] = False, + include: Annotated[ + Optional[list[str]], + typer.Option( + help="Include files matching pattern (can specify multiple).", + ), + ] = None, + exclude: Annotated[ + Optional[list[str]], + typer.Option( + help="Exclude files matching pattern (can specify multiple).", + ), + ] = None, + filter_from: Annotated[ + Optional[str], + typer.Option( + help="Read include/exclude patterns from file.", + ), + ] = None, + existing: Annotated[ + bool, + typer.Option( + "--existing", + help="Skip creating new files on receiver (only update existing files).", + ), + ] = False, + ignore_existing: Annotated[ + bool, + typer.Option( + "--ignore-existing", + help="Skip updating files that exist on receiver (only create new files).", + ), + ] = False, + verbose: Annotated[ + bool, + typer.Option( + "--verbose", + "-v", + help="Show detailed logging with reasoning.", + ), + ] = False, + quiet: Annotated[ + bool, + typer.Option( + "--quiet", + "-q", + help="Minimal output.", + ), + ] = False, + token: TokenOpt = None, +) -> None: + """Sync files between local directory and a bucket.""" + api = get_hf_api(token=token) + api.sync_bucket( + source=source, + dest=dest, + delete=delete, + ignore_times=ignore_times, + ignore_sizes=ignore_sizes, + existing=existing, + ignore_existing=ignore_existing, + include=include, + exclude=exclude, + filter_from=filter_from, + plan=plan, + apply=apply, + dry_run=dry_run, + verbose=verbose, + quiet=quiet, + ) + + +# ============================================================================= +# Cp command +# ============================================================================= + + +@buckets_cli.command( + name="cp", + examples=[ + "hf buckets cp hf://buckets/user/my-bucket/config.json", + "hf buckets cp hf://buckets/user/my-bucket/config.json ./data/", + "hf buckets cp hf://buckets/user/my-bucket/config.json my-config.json", + "hf buckets cp hf://buckets/user/my-bucket/config.json -", + "hf buckets cp my-config.json hf://buckets/user/my-bucket", + "hf buckets cp my-config.json hf://buckets/user/my-bucket/logs/", + "hf buckets cp my-config.json hf://buckets/user/my-bucket/remote-config.json", + "hf buckets cp - hf://buckets/user/my-bucket/config.json", + ], +) +def cp( + src: Annotated[str, typer.Argument(help="Source: local file, hf://buckets/... path, or - for stdin")], + dst: Annotated[ + Optional[str], typer.Argument(help="Destination: local path, hf://buckets/... path, or - for stdout") + ] = None, + quiet: QuietOpt = False, + token: TokenOpt = None, +) -> None: + """Copy a single file to or from a bucket.""" + api = get_hf_api(token=token) + + src_is_bucket = _is_bucket_path(src) + dst_is_bucket = dst is not None and _is_bucket_path(dst) + src_is_stdin = src == "-" + dst_is_stdout = dst == "-" + + # --- Validation --- + if src_is_bucket and dst_is_bucket: + raise typer.BadParameter("Remote-to-remote copy not supported.") + + if not src_is_bucket and not dst_is_bucket and not src_is_stdin: + if dst is None: + raise typer.BadParameter("Missing destination. Provide a bucket path as DST.") + raise typer.BadParameter("One of SRC or DST must be a bucket path (hf://buckets/...).") + + if src_is_stdin and not dst_is_bucket: + raise typer.BadParameter("Stdin upload requires a bucket destination.") + + if src_is_stdin and dst_is_bucket: + assert dst is not None + _, prefix = _parse_bucket_path(dst) + if prefix == "" or prefix.endswith("/"): + raise typer.BadParameter("Stdin upload requires a full destination path including filename.") + + if dst_is_stdout and not src_is_bucket: + raise typer.BadParameter("Cannot pipe to stdout for uploads.") + + if not src_is_bucket and not src_is_stdin and os.path.isdir(src): + raise typer.BadParameter("Source must be a file, not a directory. Use `hf buckets sync` for directories.") + + # --- Determine direction and execute --- + if src_is_bucket: + # Download: remote -> local or stdout + bucket_id, prefix = _parse_bucket_path(src) + if prefix == "" or prefix.endswith("/"): + raise typer.BadParameter("Source path must include a file name, not just a bucket or directory path.") + filename = prefix.rsplit("/", 1)[-1] + + if dst_is_stdout: + # Download to stdout: always suppress progress bars to avoid polluting output + # Only re-enable if they weren't already disabled by the caller + pbar_was_disabled = are_progress_bars_disabled() + if not pbar_was_disabled: + disable_progress_bars() + try: + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = os.path.join(tmp_dir, filename) + api.download_bucket_files(bucket_id, [(prefix, tmp_path)]) + with open(tmp_path, "rb") as f: + sys.stdout.buffer.write(f.read()) + finally: + if not pbar_was_disabled: + enable_progress_bars() + else: + # Download to file + if dst is None: + local_path = filename + elif os.path.isdir(dst) or dst.endswith(os.sep) or dst.endswith("/"): + local_path = os.path.join(dst, filename) + else: + local_path = dst + + # Ensure parent directory exists + parent_dir = os.path.dirname(local_path) + if parent_dir: + os.makedirs(parent_dir, exist_ok=True) + + if quiet: + disable_progress_bars() + try: + api.download_bucket_files(bucket_id, [(prefix, local_path)]) + finally: + if quiet: + enable_progress_bars() + + if not quiet: + print(f"Downloaded: {src} -> {local_path}") + + elif src_is_stdin: + # Upload from stdin + bucket_id, remote_path = _parse_bucket_path(dst) # type: ignore[arg-type] + data = sys.stdin.buffer.read() + + if quiet: + disable_progress_bars() + try: + api.batch_bucket_files(bucket_id, add=[(data, remote_path)]) + finally: + if quiet: + enable_progress_bars() + + if not quiet: + print(f"Uploaded: stdin -> {dst}") + + else: + # Upload from file + if not os.path.isfile(src): + raise typer.BadParameter(f"Source file not found: {src}") + + bucket_id, prefix = _parse_bucket_path(dst) # type: ignore[arg-type] + + if prefix == "": + remote_path = os.path.basename(src) + elif prefix.endswith("/"): + remote_path = prefix + os.path.basename(src) + else: + remote_path = prefix + + if quiet: + disable_progress_bars() + try: + api.batch_bucket_files(bucket_id, add=[(src, remote_path)]) + finally: + if quiet: + enable_progress_bars() + + if not quiet: + print(f"Uploaded: {src} -> {BUCKET_PREFIX}{bucket_id}/{remote_path}") diff --git a/huggingface_hub/cli/cache.py b/huggingface_hub/cli/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..2145d988df8d3e4280675e03f78a0d68e8855c55 --- /dev/null +++ b/huggingface_hub/cli/cache.py @@ -0,0 +1,812 @@ +# coding=utf-8 +# Copyright 2025-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains the 'hf cache' command group with cache management subcommands.""" + +import json +import re +import sys +import time +from collections import defaultdict +from dataclasses import dataclass +from enum import Enum +from typing import Annotated, Any, Callable, Dict, List, Mapping, Optional, Tuple + +import typer + +from huggingface_hub.errors import CLIError + +from ..utils import ( + ANSI, + CachedRepoInfo, + CachedRevisionInfo, + CacheNotFound, + HFCacheInfo, + _format_size, + scan_cache_dir, + tabulate, +) +from ..utils._parsing import parse_duration, parse_size +from ._cli_utils import ( + OutputFormat, + RepoIdArg, + RepoTypeOpt, + RevisionOpt, + TokenOpt, + get_hf_api, + typer_factory, +) + + +cache_cli = typer_factory(help="Manage local cache directory.") + + +#### Cache helper utilities + + +@dataclass(frozen=True) +class _DeletionResolution: + revisions: frozenset[str] + selected: dict[CachedRepoInfo, frozenset[CachedRevisionInfo]] + missing: tuple[str, ...] + + +_FILTER_PATTERN = re.compile(r"^(?P[a-zA-Z_]+)\s*(?P==|!=|>=|<=|>|<|=)\s*(?P.+)$") +_ALLOWED_OPERATORS = {"=", "!=", ">", "<", ">=", "<="} +_FILTER_KEYS = {"accessed", "modified", "refs", "size", "type"} +_SORT_KEYS = {"accessed", "modified", "name", "size"} +_SORT_PATTERN = re.compile(r"^(?P[a-zA-Z_]+)(?::(?Pasc|desc))?$") +_SORT_DEFAULT_ORDER = { + # Default ordering: accessed/modified/size are descending (newest/biggest first), name is ascending + "accessed": "desc", + "modified": "desc", + "size": "desc", + "name": "asc", +} + + +# Dynamically generate SortOptions enum from _SORT_KEYS +_sort_options_dict = {} +for key in sorted(_SORT_KEYS): + _sort_options_dict[key] = key + _sort_options_dict[f"{key}_asc"] = f"{key}:asc" + _sort_options_dict[f"{key}_desc"] = f"{key}:desc" + +SortOptions = Enum("SortOptions", _sort_options_dict, type=str, module=__name__) # type: ignore + + +@dataclass(frozen=True) +class CacheDeletionCounts: + """Simple counters summarizing cache deletions for CLI messaging.""" + + repo_count: int + partial_revision_count: int + total_revision_count: int + + +CacheEntry = Tuple[CachedRepoInfo, Optional[CachedRevisionInfo]] +RepoRefsMap = Dict[CachedRepoInfo, frozenset[str]] + + +def summarize_deletions( + selected_by_repo: Mapping[CachedRepoInfo, frozenset[CachedRevisionInfo]], +) -> CacheDeletionCounts: + """Summarize deletions across repositories.""" + repo_count = 0 + total_revisions = 0 + revisions_in_full_repos = 0 + + for repo, revisions in selected_by_repo.items(): + total_revisions += len(revisions) + if len(revisions) == len(repo.revisions): + repo_count += 1 + revisions_in_full_repos += len(revisions) + + partial_revision_count = total_revisions - revisions_in_full_repos + return CacheDeletionCounts(repo_count, partial_revision_count, total_revisions) + + +def print_cache_selected_revisions(selected_by_repo: Mapping[CachedRepoInfo, frozenset[CachedRevisionInfo]]) -> None: + """Pretty-print selected cache revisions during confirmation prompts.""" + for repo in sorted(selected_by_repo.keys(), key=lambda repo: (repo.repo_type, repo.repo_id.lower())): + repo_key = f"{repo.repo_type}/{repo.repo_id}" + revisions = sorted(selected_by_repo[repo], key=lambda rev: rev.commit_hash) + if len(revisions) == len(repo.revisions): + print(f" - {repo_key} (entire repo)") + continue + + print(f" - {repo_key}:") + for revision in revisions: + refs = " ".join(sorted(revision.refs)) or "(detached)" + print(f" {revision.commit_hash} [{refs}] {revision.size_on_disk_str}") + + +def build_cache_index( + hf_cache_info: HFCacheInfo, +) -> Tuple[ + Dict[str, CachedRepoInfo], + Dict[str, Tuple[CachedRepoInfo, CachedRevisionInfo]], +]: + """Create lookup tables so CLI commands can resolve repo ids and revisions quickly.""" + repo_lookup: dict[str, CachedRepoInfo] = {} + revision_lookup: dict[str, tuple[CachedRepoInfo, CachedRevisionInfo]] = {} + for repo in hf_cache_info.repos: + repo_key = repo.cache_id.lower() + repo_lookup[repo_key] = repo + for revision in repo.revisions: + revision_lookup[revision.commit_hash.lower()] = (repo, revision) + return repo_lookup, revision_lookup + + +def collect_cache_entries( + hf_cache_info: HFCacheInfo, *, include_revisions: bool +) -> Tuple[List[CacheEntry], RepoRefsMap]: + """Flatten cache metadata into rows consumed by `hf cache ls`.""" + entries: List[CacheEntry] = [] + repo_refs_map: RepoRefsMap = {} + sorted_repos = sorted(hf_cache_info.repos, key=lambda repo: (repo.repo_type, repo.repo_id.lower())) + for repo in sorted_repos: + repo_refs_map[repo] = frozenset({ref for revision in repo.revisions for ref in revision.refs}) + if include_revisions: + for revision in sorted(repo.revisions, key=lambda rev: rev.commit_hash): + entries.append((repo, revision)) + else: + entries.append((repo, None)) + if include_revisions: + entries.sort( + key=lambda entry: ( + entry[0].cache_id, + entry[1].commit_hash if entry[1] is not None else "", + ) + ) + else: + entries.sort(key=lambda entry: entry[0].cache_id) + return entries, repo_refs_map + + +def compile_cache_filter( + expr: str, repo_refs_map: RepoRefsMap +) -> Callable[[CachedRepoInfo, Optional[CachedRevisionInfo], float], bool]: + """Convert a `hf cache ls` filter expression into the yes/no test we apply to each cache entry before displaying it.""" + match = _FILTER_PATTERN.match(expr.strip()) + if not match: + raise ValueError(f"Invalid filter expression: '{expr}'.") + + key = match.group("key").lower() + op = match.group("op") + value_raw = match.group("value").strip() + + if op not in _ALLOWED_OPERATORS: + raise ValueError(f"Unsupported operator '{op}' in filter '{expr}'. Must be one of {list(_ALLOWED_OPERATORS)}.") + + if key not in _FILTER_KEYS: + raise ValueError(f"Unsupported filter key '{key}' in '{expr}'. Must be one of {list(_FILTER_KEYS)}.") + # at this point we know that key is in `_FILTER_KEYS` + if key == "size": + size_threshold = parse_size(value_raw) + return lambda repo, revision, _: _compare_numeric( + revision.size_on_disk if revision is not None else repo.size_on_disk, + op, + size_threshold, + ) + + if key in {"modified", "accessed"}: + seconds = parse_duration(value_raw.strip()) + + def _time_filter(repo: CachedRepoInfo, revision: Optional[CachedRevisionInfo], now: float) -> bool: + timestamp = ( + repo.last_accessed + if key == "accessed" + else revision.last_modified + if revision is not None + else repo.last_modified + ) + if timestamp is None: + return False + return _compare_numeric(now - timestamp, op, seconds) + + return _time_filter + + if key == "type": + expected = value_raw.lower() + + if op != "=": + raise ValueError(f"Only '=' is supported for 'type' filters. Got '{op}'.") + + def _type_filter(repo: CachedRepoInfo, revision: Optional[CachedRevisionInfo], _: float) -> bool: + return repo.repo_type.lower() == expected + + return _type_filter + + else: # key == "refs" + if op != "=": + raise ValueError(f"Only '=' is supported for 'refs' filters. Got {op}.") + + def _refs_filter(repo: CachedRepoInfo, revision: Optional[CachedRevisionInfo], _: float) -> bool: + refs = revision.refs if revision is not None else repo_refs_map.get(repo, frozenset()) + return value_raw.lower() in [ref.lower() for ref in refs] + + return _refs_filter + + +def _build_cache_export_payload( + entries: List[CacheEntry], *, include_revisions: bool, repo_refs_map: RepoRefsMap +) -> List[Dict[str, Any]]: + """Normalize cache entries into serializable records for JSON/CSV exports.""" + payload: List[Dict[str, Any]] = [] + for repo, revision in entries: + if include_revisions: + if revision is None: + continue + record: Dict[str, Any] = { + "repo_id": repo.repo_id, + "repo_type": repo.repo_type, + "revision": revision.commit_hash, + "snapshot_path": str(revision.snapshot_path), + "size_on_disk": revision.size_on_disk, + "last_accessed": repo.last_accessed, + "last_modified": revision.last_modified, + "refs": sorted(revision.refs), + } + else: + record = { + "repo_id": repo.repo_id, + "repo_type": repo.repo_type, + "size_on_disk": repo.size_on_disk, + "last_accessed": repo.last_accessed, + "last_modified": repo.last_modified, + "refs": sorted(repo_refs_map.get(repo, frozenset())), + } + payload.append(record) + return payload + + +def print_cache_entries_table( + entries: List[CacheEntry], *, include_revisions: bool, repo_refs_map: RepoRefsMap +) -> None: + """Render cache entries as a table and show a human-readable summary.""" + if not entries: + message = "No cached revisions found." if include_revisions else "No cached repositories found." + print(message) + return + table_rows: List[List[str]] + if include_revisions: + headers = ["ID", "REVISION", "SIZE", "LAST_MODIFIED", "REFS"] + table_rows = [ + [ + repo.cache_id, + revision.commit_hash, + revision.size_on_disk_str.rjust(8), + revision.last_modified_str, + " ".join(sorted(revision.refs)), + ] + for repo, revision in entries + if revision is not None + ] + else: + headers = ["ID", "SIZE", "LAST_ACCESSED", "LAST_MODIFIED", "REFS"] + table_rows = [ + [ + repo.cache_id, + repo.size_on_disk_str.rjust(8), + repo.last_accessed_str or "", + repo.last_modified_str, + " ".join(sorted(repo_refs_map.get(repo, frozenset()))), + ] + for repo, _ in entries + ] + + print(tabulate(table_rows, headers=headers)) # type: ignore[arg-type] + + unique_repos = {repo for repo, _ in entries} + repo_count = len(unique_repos) + if include_revisions: + revision_count = sum(1 for _, revision in entries if revision is not None) + total_size = sum(revision.size_on_disk for _, revision in entries if revision is not None) + else: + revision_count = sum(len(repo.revisions) for repo in unique_repos) + total_size = sum(repo.size_on_disk for repo in unique_repos) + + summary = f"\nFound {repo_count} repo(s) for a total of {revision_count} revision(s) and {_format_size(total_size)} on disk." + print(ANSI.bold(summary)) + + +def print_cache_entries_json( + entries: List[CacheEntry], *, include_revisions: bool, repo_refs_map: RepoRefsMap +) -> None: + """Dump cache entries as JSON for scripting or automation.""" + payload = _build_cache_export_payload(entries, include_revisions=include_revisions, repo_refs_map=repo_refs_map) + json.dump(payload, sys.stdout, indent=2) + sys.stdout.write("\n") + + +def _compare_numeric(left: Optional[float], op: str, right: float) -> bool: + """Evaluate numeric comparisons for filters.""" + if left is None: + return False + + comparisons = { + "=": left == right, + "!=": left != right, + ">": left > right, + "<": left < right, + ">=": left >= right, + "<=": left <= right, + } + + if op not in comparisons: + raise ValueError(f"Unsupported numeric comparison operator: {op}") + + return comparisons[op] + + +def compile_cache_sort(sort_expr: str) -> tuple[Callable[[CacheEntry], tuple[Any, ...]], bool]: + """Convert a `hf cache ls` sort expression into a key function for sorting entries. + + Returns: + A tuple of (key_function, reverse_flag) where reverse_flag indicates whether + to sort in descending order (True) or ascending order (False). + """ + match = _SORT_PATTERN.match(sort_expr.strip().lower()) + if not match: + raise ValueError(f"Invalid sort expression: '{sort_expr}'. Expected format: 'key' or 'key:asc' or 'key:desc'.") + + key = match.group("key").lower() + explicit_order = match.group("order") + + if key not in _SORT_KEYS: + raise ValueError(f"Unsupported sort key '{key}' in '{sort_expr}'. Must be one of {list(_SORT_KEYS)}.") + + # Use explicit order if provided, otherwise use default for the key + order = explicit_order if explicit_order else _SORT_DEFAULT_ORDER[key] + reverse = order == "desc" + + def _sort_key(entry: CacheEntry) -> tuple[Any, ...]: + repo, revision = entry + + if key == "name": + # Sort by cache_id (repo type/id) + value: Any = repo.cache_id.lower() + return (value,) + + if key == "size": + # Use revision size if available, otherwise repo size + value = revision.size_on_disk if revision is not None else repo.size_on_disk + return (value,) + + if key == "accessed": + # For revisions, accessed is not available per-revision, use repo's last_accessed + # For repos, use repo's last_accessed + value = repo.last_accessed if repo.last_accessed is not None else 0.0 + return (value,) + + if key == "modified": + # Use revision's last_modified if available, otherwise repo's last_modified + if revision is not None: + value = revision.last_modified if revision.last_modified is not None else 0.0 + else: + value = repo.last_modified if repo.last_modified is not None else 0.0 + return (value,) + + # Should never reach here due to validation above + raise ValueError(f"Unsupported sort key: {key}") + + return _sort_key, reverse + + +def _resolve_deletion_targets(hf_cache_info: HFCacheInfo, targets: list[str]) -> _DeletionResolution: + """Resolve the deletion targets into a deletion resolution.""" + repo_lookup, revision_lookup = build_cache_index(hf_cache_info) + + selected: dict[CachedRepoInfo, set[CachedRevisionInfo]] = defaultdict(set) + revisions: set[str] = set() + missing: list[str] = [] + + for raw_target in targets: + target = raw_target.strip() + if not target: + continue + lowered = target.lower() + + if re.fullmatch(r"[0-9a-fA-F]{40}", lowered): + match = revision_lookup.get(lowered) + if match is None: + missing.append(raw_target) + continue + repo, revision = match + selected[repo].add(revision) + revisions.add(revision.commit_hash) + continue + + matched_repo = repo_lookup.get(lowered) + if matched_repo is None: + missing.append(raw_target) + continue + + for revision in matched_repo.revisions: + selected[matched_repo].add(revision) + revisions.add(revision.commit_hash) + + frozen_selected = {repo: frozenset(revs) for repo, revs in selected.items()} + return _DeletionResolution( + revisions=frozenset(revisions), + selected=frozen_selected, + missing=tuple(missing), + ) + + +#### Cache CLI commands + + +@cache_cli.command( + "list | ls", + examples=[ + "hf cache ls", + "hf cache ls --revisions", + 'hf cache ls --filter "size>1GB" --limit 20', + "hf cache ls --format json", + ], +) +def ls( + cache_dir: Annotated[ + Optional[str], + typer.Option( + help="Cache directory to scan (defaults to Hugging Face cache).", + ), + ] = None, + revisions: Annotated[ + bool, + typer.Option( + help="Include revisions in the output instead of aggregated repositories.", + ), + ] = False, + filter: Annotated[ + Optional[list[str]], + typer.Option( + "-f", + "--filter", + help="Filter entries (e.g. 'size>1GB', 'type=model', 'accessed>7d'). Can be used multiple times.", + ), + ] = None, + format: Annotated[ + OutputFormat, + typer.Option( + help="Output format.", + ), + ] = OutputFormat.table, + quiet: Annotated[ + bool, + typer.Option( + "-q", + "--quiet", + help="Print only IDs (repo IDs or revision hashes).", + ), + ] = False, + sort: Annotated[ + Optional[SortOptions], + typer.Option( + help="Sort entries by key. Supported keys: 'accessed', 'modified', 'name', 'size'. " + "Append ':asc' or ':desc' to explicitly set the order (e.g., 'modified:asc'). " + "Defaults: 'accessed', 'modified', 'size' default to 'desc' (newest/biggest first); " + "'name' defaults to 'asc' (alphabetical).", + ), + ] = None, + limit: Annotated[ + Optional[int], + typer.Option( + help="Limit the number of results returned. Returns only the top N entries after sorting.", + ), + ] = None, +) -> None: + """List cached repositories or revisions.""" + try: + hf_cache_info = scan_cache_dir(cache_dir) + except CacheNotFound as exc: + raise CLIError(f"Cache directory not found: {exc.cache_dir}") from exc + + filters = filter or [] + + entries, repo_refs_map = collect_cache_entries(hf_cache_info, include_revisions=revisions) + try: + filter_fns = [compile_cache_filter(expr, repo_refs_map) for expr in filters] + except ValueError as exc: + raise typer.BadParameter(str(exc)) from exc + + now = time.time() + for fn in filter_fns: + entries = [entry for entry in entries if fn(entry[0], entry[1], now)] + + # Apply sorting if requested + if sort: + try: + sort_key_fn, reverse = compile_cache_sort(sort.value) + entries.sort(key=sort_key_fn, reverse=reverse) + except ValueError as exc: + raise typer.BadParameter(str(exc)) from exc + + # Apply limit if requested + if limit is not None: + if limit < 0: + raise typer.BadParameter(f"Limit must be a positive integer, got {limit}.") + entries = entries[:limit] + + if quiet: + for repo, revision in entries: + print(revision.commit_hash if revision is not None else repo.cache_id) + return + + formatters = { + OutputFormat.table: print_cache_entries_table, + OutputFormat.json: print_cache_entries_json, + } + return formatters[format](entries, include_revisions=revisions, repo_refs_map=repo_refs_map) + + +@cache_cli.command( + examples=[ + "hf cache rm model/gpt2", + "hf cache rm ", + "hf cache rm model/gpt2 --dry-run", + "hf cache rm model/gpt2 --yes", + ], +) +def rm( + targets: Annotated[ + list[str], + typer.Argument( + help="One or more repo IDs (e.g. model/bert-base-uncased) or revision hashes to delete.", + ), + ], + cache_dir: Annotated[ + Optional[str], + typer.Option( + help="Cache directory to scan (defaults to Hugging Face cache).", + ), + ] = None, + yes: Annotated[ + bool, + typer.Option( + "-y", + "--yes", + help="Skip confirmation prompt.", + ), + ] = False, + dry_run: Annotated[ + bool, + typer.Option( + help="Preview deletions without removing anything.", + ), + ] = False, +) -> None: + """Remove cached repositories or revisions.""" + try: + hf_cache_info = scan_cache_dir(cache_dir) + except CacheNotFound as exc: + raise CLIError(f"Cache directory not found: {exc.cache_dir}") from exc + + resolution = _resolve_deletion_targets(hf_cache_info, targets) + + if resolution.missing: + print("Could not find the following targets in the cache:") + for entry in resolution.missing: + print(f" - {entry}") + + if len(resolution.revisions) == 0: + print("Nothing to delete.") + raise typer.Exit(code=0) + + strategy = hf_cache_info.delete_revisions(*sorted(resolution.revisions)) + counts = summarize_deletions(resolution.selected) + + summary_parts: list[str] = [] + if counts.repo_count: + summary_parts.append(f"{counts.repo_count} repo(s)") + if counts.partial_revision_count: + summary_parts.append(f"{counts.partial_revision_count} revision(s)") + if not summary_parts: + summary_parts.append(f"{counts.total_revision_count} revision(s)") + + summary_text = " and ".join(summary_parts) + print(f"About to delete {summary_text} totalling {strategy.expected_freed_size_str}.") + print_cache_selected_revisions(resolution.selected) + + if dry_run: + print("Dry run: no files were deleted.") + return + + if not yes and not typer.confirm("Proceed with deletion?", default=False): + print("Deletion cancelled.") + return + + strategy.execute() + counts = summarize_deletions(resolution.selected) + print( + f"Deleted {counts.repo_count} repo(s) and {counts.total_revision_count} revision(s); freed {strategy.expected_freed_size_str}." + ) + + +@cache_cli.command(examples=["hf cache prune", "hf cache prune --dry-run"]) +def prune( + cache_dir: Annotated[ + Optional[str], + typer.Option( + help="Cache directory to scan (defaults to Hugging Face cache).", + ), + ] = None, + yes: Annotated[ + bool, + typer.Option( + "-y", + "--yes", + help="Skip confirmation prompt.", + ), + ] = False, + dry_run: Annotated[ + bool, + typer.Option( + help="Preview deletions without removing anything.", + ), + ] = False, +) -> None: + """Remove detached revisions from the cache.""" + try: + hf_cache_info = scan_cache_dir(cache_dir) + except CacheNotFound as exc: + raise CLIError(f"Cache directory not found: {exc.cache_dir}") from exc + + selected: dict[CachedRepoInfo, frozenset[CachedRevisionInfo]] = {} + revisions: set[str] = set() + for repo in hf_cache_info.repos: + detached = frozenset(revision for revision in repo.revisions if len(revision.refs) == 0) + if not detached: + continue + selected[repo] = detached + revisions.update(revision.commit_hash for revision in detached) + + if len(revisions) == 0: + print("No unreferenced revisions found. Nothing to prune.") + return + + resolution = _DeletionResolution( + revisions=frozenset(revisions), + selected=selected, + missing=(), + ) + strategy = hf_cache_info.delete_revisions(*sorted(resolution.revisions)) + counts = summarize_deletions(selected) + + print( + f"About to delete {counts.total_revision_count} unreferenced revision(s) ({strategy.expected_freed_size_str} total)." + ) + print_cache_selected_revisions(selected) + + if dry_run: + print("Dry run: no files were deleted.") + return + + if not yes and not typer.confirm("Proceed?"): + print("Pruning cancelled.") + return + + strategy.execute() + print(f"Deleted {counts.total_revision_count} unreferenced revision(s); freed {strategy.expected_freed_size_str}.") + + +@cache_cli.command( + examples=[ + "hf cache verify gpt2", + "hf cache verify gpt2 --revision refs/pr/1", + "hf cache verify my-dataset --repo-type dataset", + ], +) +def verify( + repo_id: RepoIdArg, + repo_type: RepoTypeOpt = RepoTypeOpt.model, + revision: RevisionOpt = None, + cache_dir: Annotated[ + Optional[str], + typer.Option( + help="Cache directory to use when verifying files from cache (defaults to Hugging Face cache).", + ), + ] = None, + local_dir: Annotated[ + Optional[str], + typer.Option( + help="If set, verify files under this directory instead of the cache.", + ), + ] = None, + fail_on_missing_files: Annotated[ + bool, + typer.Option( + "--fail-on-missing-files", + help="Fail if some files exist on the remote but are missing locally.", + ), + ] = False, + fail_on_extra_files: Annotated[ + bool, + typer.Option( + "--fail-on-extra-files", + help="Fail if some files exist locally but are not present on the remote revision.", + ), + ] = False, + token: TokenOpt = None, +) -> None: + """Verify checksums for a single repo revision from cache or a local directory. + + Examples: + - Verify main revision in cache: `hf cache verify gpt2` + - Verify specific revision: `hf cache verify gpt2 --revision refs/pr/1` + - Verify dataset: `hf cache verify karpathy/fineweb-edu-100b-shuffle --repo-type dataset` + - Verify local dir: `hf cache verify deepseek-ai/DeepSeek-OCR --local-dir /path/to/repo` + """ + + if local_dir is not None and cache_dir is not None: + print("Cannot pass both --local-dir and --cache-dir. Use one or the other.") + raise typer.Exit(code=2) + + api = get_hf_api(token=token) + + result = api.verify_repo_checksums( + repo_id=repo_id, + repo_type=repo_type.value if hasattr(repo_type, "value") else str(repo_type), + revision=revision, + local_dir=local_dir, + cache_dir=cache_dir, + token=token, + ) + + exit_code = 0 + + has_mismatches = bool(result.mismatches) + if has_mismatches: + print("❌ Checksum verification failed for the following file(s):") + for m in result.mismatches: + print(f" - {m['path']}: expected {m['expected']} ({m['algorithm']}), got {m['actual']}") + exit_code = 1 + + if result.missing_paths: + if fail_on_missing_files: + print("Missing files (present remotely, absent locally):") + for p in result.missing_paths: + print(f" - {p}") + exit_code = 1 + else: + warning = ( + f"{len(result.missing_paths)} remote file(s) are missing locally. " + "Use --fail-on-missing-files for details." + ) + print(f"⚠️ {warning}") + + if result.extra_paths: + if fail_on_extra_files: + print("Extra files (present locally, absent remotely):") + for p in result.extra_paths: + print(f" - {p}") + exit_code = 1 + else: + warning = ( + f"{len(result.extra_paths)} local file(s) do not exist on the remote repo. " + "Use --fail-on-extra-files for details." + ) + print(f"⚠️ {warning}") + + verified_location = result.verified_path + + if exit_code != 0: + print(f"❌ Verification failed for '{repo_id}' ({repo_type.value}) in {verified_location}.") + print(f" Revision: {result.revision}") + raise typer.Exit(code=exit_code) + + print(f"✅ Verified {result.checked_count} file(s) for '{repo_id}' ({repo_type.value}) in {verified_location}") + print(" All checksums match.") diff --git a/huggingface_hub/cli/collections.py b/huggingface_hub/cli/collections.py new file mode 100644 index 0000000000000000000000000000000000000000..572e3c38ab1d608da13ba0ea913989a71261935e --- /dev/null +++ b/huggingface_hub/cli/collections.py @@ -0,0 +1,331 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains commands to interact with collections on the Hugging Face Hub. + +Usage: + # list collections on the Hub + hf collections ls + + # list collections for a specific user + hf collections ls --owner username + + # get info about a collection + hf collections info username/collection-slug + + # create a new collection + hf collections create "My Collection" --description "A collection of models" + + # add an item to a collection + hf collections add-item username/collection-slug username/model-name model + + # delete a collection + hf collections delete username/collection-slug +""" + +import enum +import json +from typing import Annotated, Optional, get_args + +import typer + +from huggingface_hub.hf_api import CollectionItemType_T, CollectionSort_T + +from ._cli_utils import ( + FormatOpt, + LimitOpt, + OutputFormat, + QuietOpt, + TokenOpt, + api_object_to_dict, + get_hf_api, + print_list_output, + typer_factory, +) + + +# Build enums dynamically from Literal types to avoid duplication +_COLLECTION_ITEM_TYPES = get_args(CollectionItemType_T) +CollectionItemType = enum.Enum("CollectionItemType", {t: t for t in _COLLECTION_ITEM_TYPES}, type=str) # type: ignore[misc] + +_COLLECTION_SORT_OPTIONS = get_args(CollectionSort_T) +CollectionSort = enum.Enum("CollectionSort", {s: s for s in _COLLECTION_SORT_OPTIONS}, type=str) # type: ignore[misc] + + +collections_cli = typer_factory(help="Interact with collections on the Hub.") + + +@collections_cli.command( + "list | ls", + examples=[ + "hf collections ls", + "hf collections ls --owner nvidia", + "hf collections ls --item models/teknium/OpenHermes-2.5-Mistral-7B --limit 10", + ], +) +def collections_ls( + owner: Annotated[ + Optional[str], + typer.Option(help="Filter by owner username or organization."), + ] = None, + item: Annotated[ + Optional[str], + typer.Option( + help='Filter collections containing a specific item (e.g., "models/gpt2", "datasets/squad", "papers/2311.12983").' + ), + ] = None, + sort: Annotated[ + Optional[CollectionSort], + typer.Option(help="Sort results by last modified, trending, or upvotes."), + ] = None, + limit: LimitOpt = 10, + format: FormatOpt = OutputFormat.table, + quiet: QuietOpt = False, + token: TokenOpt = None, +) -> None: + """List collections on the Hub.""" + api = get_hf_api(token=token) + sort_key = sort.value if sort else None + results = [ + api_object_to_dict(collection) + for collection in api.list_collections( + owner=owner, + item=item, + sort=sort_key, # type: ignore[arg-type] + limit=limit, + ) + ] + print_list_output(results, format=format, quiet=quiet) + + +@collections_cli.command( + "info", + examples=[ + "hf collections info username/my-collection-slug", + ], +) +def collections_info( + collection_slug: Annotated[str, typer.Argument(help="The collection slug (e.g., 'username/collection-slug').")], + token: TokenOpt = None, +) -> None: + """Get info about a collection on the Hub.""" + api = get_hf_api(token=token) + collection = api.get_collection(collection_slug) + print(json.dumps(api_object_to_dict(collection), indent=2)) + + +@collections_cli.command( + "create", + examples=[ + 'hf collections create "My Models"', + 'hf collections create "My Models" --description "A collection of my favorite models" --private', + 'hf collections create "Org Collection" --namespace my-org', + ], +) +def collections_create( + title: Annotated[str, typer.Argument(help="The title of the collection.")], + namespace: Annotated[ + Optional[str], + typer.Option(help="The namespace (username or organization). Defaults to the authenticated user."), + ] = None, + description: Annotated[ + Optional[str], + typer.Option(help="A description for the collection."), + ] = None, + private: Annotated[ + bool, + typer.Option(help="Create a private collection."), + ] = False, + exists_ok: Annotated[ + bool, + typer.Option(help="Do not raise an error if the collection already exists."), + ] = False, + token: TokenOpt = None, +) -> None: + """Create a new collection on the Hub.""" + api = get_hf_api(token=token) + collection = api.create_collection( + title=title, + namespace=namespace, + description=description, + private=private, + exists_ok=exists_ok, + ) + print(f"Collection created: {collection.url}") + print(json.dumps(api_object_to_dict(collection), indent=2)) + + +@collections_cli.command( + "update", + examples=[ + 'hf collections update username/my-collection --title "New Title"', + 'hf collections update username/my-collection --description "Updated description"', + "hf collections update username/my-collection --private --theme green", + ], +) +def collections_update( + collection_slug: Annotated[str, typer.Argument(help="The collection slug (e.g., 'username/collection-slug').")], + title: Annotated[ + Optional[str], + typer.Option(help="The new title for the collection."), + ] = None, + description: Annotated[ + Optional[str], + typer.Option(help="The new description for the collection."), + ] = None, + position: Annotated[ + Optional[int], + typer.Option(help="The new position of the collection in the owner's list."), + ] = None, + private: Annotated[ + Optional[bool], + typer.Option(help="Whether the collection should be private."), + ] = None, + theme: Annotated[ + Optional[str], + typer.Option(help="The theme color for the collection (e.g., 'green', 'blue')."), + ] = None, + token: TokenOpt = None, +) -> None: + """Update a collection's metadata on the Hub.""" + api = get_hf_api(token=token) + collection = api.update_collection_metadata( + collection_slug=collection_slug, + title=title, + description=description, + position=position, + private=private, + theme=theme, + ) + print(f"Collection updated: {collection.url}") + print(json.dumps(api_object_to_dict(collection), indent=2)) + + +@collections_cli.command( + "delete", + examples=[ + "hf collections delete username/my-collection", + "hf collections delete username/my-collection --missing-ok", + ], +) +def collections_delete( + collection_slug: Annotated[str, typer.Argument(help="The collection slug (e.g., 'username/collection-slug').")], + missing_ok: Annotated[ + bool, + typer.Option(help="Do not raise an error if the collection doesn't exist."), + ] = False, + token: TokenOpt = None, +) -> None: + """Delete a collection from the Hub.""" + api = get_hf_api(token=token) + api.delete_collection(collection_slug, missing_ok=missing_ok) + print(f"Collection deleted: {collection_slug}") + + +@collections_cli.command( + "add-item", + examples=[ + "hf collections add-item username/my-collection moonshotai/kimi-k2 model", + 'hf collections add-item username/my-collection Qwen/DeepPlanning dataset --note "Useful dataset"', + "hf collections add-item username/my-collection Tongyi-MAI/Z-Image space", + ], +) +def collections_add_item( + collection_slug: Annotated[str, typer.Argument(help="The collection slug (e.g., 'username/collection-slug').")], + item_id: Annotated[ + str, typer.Argument(help="The ID of the item to add (repo_id for repos, paper ID for papers).") + ], + item_type: Annotated[ + CollectionItemType, + typer.Argument(help="The type of item (model, dataset, space, paper, or collection)."), + ], + note: Annotated[ + Optional[str], + typer.Option(help="A note to attach to the item (max 500 characters)."), + ] = None, + exists_ok: Annotated[ + bool, + typer.Option(help="Do not raise an error if the item is already in the collection."), + ] = False, + token: TokenOpt = None, +) -> None: + """Add an item to a collection.""" + api = get_hf_api(token=token) + collection = api.add_collection_item( + collection_slug=collection_slug, + item_id=item_id, + item_type=item_type.value, # type: ignore[arg-type] + note=note, + exists_ok=exists_ok, + ) + print(f"Item added to collection: {collection_slug}") + print(json.dumps(api_object_to_dict(collection), indent=2)) + + +@collections_cli.command( + "update-item", + examples=[ + 'hf collections update-item username/my-collection ITEM_OBJECT_ID --note "Updated note"', + "hf collections update-item username/my-collection ITEM_OBJECT_ID --position 0", + ], +) +def collections_update_item( + collection_slug: Annotated[str, typer.Argument(help="The collection slug (e.g., 'username/collection-slug').")], + item_object_id: Annotated[ + str, + typer.Argument(help="The ID of the item in the collection (from 'item_object_id' field, not the repo_id)."), + ], + note: Annotated[ + Optional[str], + typer.Option(help="A new note for the item (max 500 characters)."), + ] = None, + position: Annotated[ + Optional[int], + typer.Option(help="The new position of the item in the collection."), + ] = None, + token: TokenOpt = None, +) -> None: + """Update an item in a collection.""" + api = get_hf_api(token=token) + api.update_collection_item( + collection_slug=collection_slug, + item_object_id=item_object_id, + note=note, + position=position, + ) + print(f"Item updated in collection: {collection_slug}") + + +@collections_cli.command("delete-item") +def collections_delete_item( + collection_slug: Annotated[str, typer.Argument(help="The collection slug (e.g., 'username/collection-slug').")], + item_object_id: Annotated[ + str, + typer.Argument( + help="The ID of the item in the collection (retrieved from `item_object_id` field returned by 'hf collections info'." + ), + ], + missing_ok: Annotated[ + bool, + typer.Option(help="Do not raise an error if the item doesn't exist."), + ] = False, + token: TokenOpt = None, +) -> None: + """Delete an item from a collection.""" + api = get_hf_api(token=token) + api.delete_collection_item( + collection_slug=collection_slug, + item_object_id=item_object_id, + missing_ok=missing_ok, + ) + print(f"Item deleted from collection: {collection_slug}") diff --git a/huggingface_hub/cli/datasets.py b/huggingface_hub/cli/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..8b4eede41064e24d76dd6ac7082abdd2b54930db --- /dev/null +++ b/huggingface_hub/cli/datasets.py @@ -0,0 +1,180 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains commands to interact with datasets on the Hugging Face Hub. + +Usage: + # list datasets on the Hub + hf datasets ls + + # list datasets with a search query + hf datasets ls --search "code" + + # get info about a dataset + hf datasets info HuggingFaceFW/fineweb +""" + +import enum +import json +from typing import Annotated, Optional, get_args + +import typer + +from huggingface_hub._dataset_viewer import execute_raw_sql_query +from huggingface_hub.errors import CLIError, RepositoryNotFoundError, RevisionNotFoundError +from huggingface_hub.hf_api import DatasetSort_T, ExpandDatasetProperty_T + +from ._cli_utils import ( + AuthorOpt, + FilterOpt, + FormatOpt, + LimitOpt, + OutputFormat, + QuietOpt, + RevisionOpt, + SearchOpt, + TokenOpt, + api_object_to_dict, + get_hf_api, + make_expand_properties_parser, + print_list_output, + typer_factory, +) + + +_EXPAND_PROPERTIES = sorted(get_args(ExpandDatasetProperty_T)) +_SORT_OPTIONS = get_args(DatasetSort_T) +DatasetSortEnum = enum.Enum("DatasetSortEnum", {s: s for s in _SORT_OPTIONS}, type=str) # type: ignore[misc] + + +ExpandOpt = Annotated[ + Optional[str], + typer.Option( + help=f"Comma-separated properties to expand. Example: '--expand=downloads,likes,tags'. Valid: {', '.join(_EXPAND_PROPERTIES)}.", + callback=make_expand_properties_parser(_EXPAND_PROPERTIES), + ), +] + + +datasets_cli = typer_factory(help="Interact with datasets on the Hub.") + + +@datasets_cli.command( + "list | ls", + examples=[ + "hf datasets ls", + "hf datasets ls --sort downloads --limit 10", + 'hf datasets ls --search "code"', + ], +) +def datasets_ls( + search: SearchOpt = None, + author: AuthorOpt = None, + filter: FilterOpt = None, + sort: Annotated[ + Optional[DatasetSortEnum], + typer.Option(help="Sort results."), + ] = None, + limit: LimitOpt = 10, + expand: ExpandOpt = None, + format: FormatOpt = OutputFormat.table, + quiet: QuietOpt = False, + token: TokenOpt = None, +) -> None: + """List datasets on the Hub.""" + api = get_hf_api(token=token) + sort_key = sort.value if sort else None + results = [ + api_object_to_dict(dataset_info) + for dataset_info in api.list_datasets( + filter=filter, + author=author, + search=search, + sort=sort_key, + limit=limit, + expand=expand, # type: ignore[arg-type] + ) + ] + print_list_output(results, format=format, quiet=quiet) + + +@datasets_cli.command( + "info", + examples=[ + "hf datasets info HuggingFaceFW/fineweb", + "hf datasets info my-dataset --expand downloads,likes,tags", + ], +) +def datasets_info( + dataset_id: Annotated[str, typer.Argument(help="The dataset ID (e.g. `username/repo-name`).")], + revision: RevisionOpt = None, + expand: ExpandOpt = None, + token: TokenOpt = None, +) -> None: + """Get info about a dataset on the Hub.""" + api = get_hf_api(token=token) + try: + info = api.dataset_info(repo_id=dataset_id, revision=revision, expand=expand) # type: ignore[arg-type] + except RepositoryNotFoundError as e: + raise CLIError(f"Dataset '{dataset_id}' not found.") from e + except RevisionNotFoundError as e: + raise CLIError(f"Revision '{revision}' not found on '{dataset_id}'.") from e + print(json.dumps(api_object_to_dict(info), indent=2)) + + +@datasets_cli.command( + "parquet", + examples=[ + "hf datasets parquet cfahlgren1/hub-stats", + "hf datasets parquet cfahlgren1/hub-stats --subset models", + "hf datasets parquet cfahlgren1/hub-stats --split train", + "hf datasets parquet cfahlgren1/hub-stats --format json", + ], +) +def datasets_parquet( + dataset_id: Annotated[str, typer.Argument(help="The dataset ID (e.g. `username/repo-name`).")], + subset: Annotated[Optional[str], typer.Option("--subset", help="Filter parquet entries by subset/config.")] = None, + split: Annotated[Optional[str], typer.Option(help="Filter parquet entries by split.")] = None, + format: FormatOpt = OutputFormat.table, + quiet: QuietOpt = False, + token: TokenOpt = None, +) -> None: + """List parquet file URLs available for a dataset.""" + api = get_hf_api(token=token) + entries = api.list_dataset_parquet_files(repo_id=dataset_id, config=subset) + filtered = [entry for entry in entries if split is None or entry.split == split] + results = [ + {"subset": entry.config, "split": entry.split, "url": entry.url, "size": entry.size} for entry in filtered + ] + print_list_output(results, format=format, quiet=quiet, id_key="url") + + +@datasets_cli.command( + "sql", + examples=[ + "hf datasets sql \"SELECT COUNT(*) AS rows FROM read_parquet('https://huggingface.co/api/datasets/cfahlgren1/hub-stats/parquet/models/train/0.parquet')\"", + "hf datasets sql \"SELECT * FROM read_parquet('https://huggingface.co/api/datasets/cfahlgren1/hub-stats/parquet/models/train/0.parquet') LIMIT 5\" --format json", + ], +) +def datasets_sql( + sql: Annotated[str, typer.Argument(help="Raw SQL query to execute.")], + format: FormatOpt = OutputFormat.table, + token: TokenOpt = None, +) -> None: + """Execute a raw SQL query with DuckDB against dataset parquet URLs.""" + try: + result = execute_raw_sql_query(sql_query=sql, token=token) + except ImportError as e: + raise CLIError(str(e)) from e + + print_list_output(result, format=format, quiet=False) diff --git a/huggingface_hub/cli/discussions.py b/huggingface_hub/cli/discussions.py new file mode 100644 index 0000000000000000000000000000000000000000..79d58001cde8287c85b26406b5d24c0853b8929f --- /dev/null +++ b/huggingface_hub/cli/discussions.py @@ -0,0 +1,593 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains commands to interact with discussions and pull requests on the Hugging Face Hub.""" + +import enum +import json +import sys +from pathlib import Path +from typing import Annotated, Optional + +import typer + +from huggingface_hub import constants +from huggingface_hub.community import DiscussionComment, DiscussionWithDetails +from huggingface_hub.utils import ANSI + +from ._cli_utils import ( + AuthorOpt, + FormatOpt, + LimitOpt, + OutputFormat, + QuietOpt, + RepoIdArg, + RepoType, + RepoTypeOpt, + TokenOpt, + _format_cell, + api_object_to_dict, + get_hf_api, + print_list_output, + typer_factory, +) + + +class DiscussionStatus(str, enum.Enum): + open = "open" + closed = "closed" + merged = "merged" + draft = "draft" + all = "all" + + +class DiscussionKind(str, enum.Enum): + all = "all" + discussion = "discussion" + pull_request = "pull_request" + + +class InfoFormat(str, enum.Enum): + """Output format for the info command.""" + + text = "text" + json = "json" + + +# "merged" and "draft" are valid Discussion statuses but the Hub API filter +# (DiscussionStatusFilter) only accepts "all", "open", "closed". When the user +# asks for merged/draft we fetch with api_status=None (i.e. all) and filter +# client-side. +_CLIENT_SIDE_STATUSES = {"merged", "draft"} + + +DiscussionNumArg = Annotated[ + int, + typer.Argument( + help="The discussion or pull request number.", + min=1, + ), +] + + +def _format_status(status: str) -> str: + if status == "open": + return ANSI.green("open") + elif status == "closed": + return ANSI.red("closed") + elif status == "merged": + return ANSI.blue("merged") + elif status == "draft": + return ANSI.yellow("draft") + return status + + +def _read_body(body: Optional[str], body_file: Optional[Path]) -> Optional[str]: + """Resolve body text from --body or --body-file (supports '-' for stdin).""" + if body is not None and body_file is not None: + raise typer.BadParameter("Cannot use both --body and --body-file.") + if body_file is not None: + if str(body_file) == "-": + return sys.stdin.read() + return body_file.read_text(encoding="utf-8") + return body + + +def _print_discussion_info(details: DiscussionWithDetails, show_comments: bool = False) -> None: + kind = "Pull Request" if details.is_pull_request else "Discussion" + + print(f"{ANSI.bold(details.title)} {ANSI.gray(f'#{details.num}')}") + parts = [_format_status(details.status), details.author, details.created_at.strftime("%Y-%m-%d %H:%M")] + if details.is_pull_request and details.target_branch: + parts.append(f"into {ANSI.bold(details.target_branch)}") + print(f"{kind}: {' · '.join(parts)}") + + if details.is_pull_request and details.conflicting_files: + if details.conflicting_files is True: + print(ANSI.yellow("Has conflicting files")) + else: + print(ANSI.yellow(f"Conflicting files: {', '.join(details.conflicting_files)}")) + + body = None + comments = [] + for event in details.events: + if isinstance(event, DiscussionComment) and not event.hidden: + if body is None: + body = event + else: + comments.append(event) + + if body and body.content.strip(): + print() + print(body.content.strip()) + + if show_comments and comments: + print() + print(ANSI.gray("─" * 60)) + for comment in comments: + print() + print(f"{ANSI.bold(comment.author)} · {comment.created_at.strftime('%Y-%m-%d %H:%M')}") + print(comment.content.strip()) + elif comments: + print() + print(ANSI.gray(f"{len(comments)} comment{'s' if len(comments) != 1 else ''} (use --comments to show)")) + + print() + print(f"View on Hub: {ANSI.blue(details.url)}") + + +discussions_cli = typer_factory(help="Manage discussions and pull requests on the Hub.") + + +@discussions_cli.command( + "list | ls", + examples=[ + "hf discussions list username/my-model", + "hf discussions list username/my-model --kind pull_request --status merged", + "hf discussions list username/my-dataset --type dataset --status closed", + "hf discussions list username/my-model --author alice --format json", + ], +) +def discussion_list( + repo_id: RepoIdArg, + status: Annotated[ + DiscussionStatus, + typer.Option( + "-s", + "--status", + help="Filter by status (open, closed, merged, draft, all).", + ), + ] = DiscussionStatus.open, + kind: Annotated[ + DiscussionKind, + typer.Option( + "-k", + "--kind", + help="Filter by kind (discussion, pull_request, all).", + ), + ] = DiscussionKind.all, + author: AuthorOpt = None, + limit: LimitOpt = 30, + repo_type: RepoTypeOpt = RepoType.model, + format: FormatOpt = OutputFormat.table, + quiet: QuietOpt = False, + token: TokenOpt = None, +) -> None: + """List discussions and pull requests on a repo.""" + api = get_hf_api(token=token) + + api_status: Optional[constants.DiscussionStatusFilter] + if status == DiscussionStatus.open: + api_status = "open" + elif status == DiscussionStatus.closed: + api_status = "closed" + else: + api_status = None + + api_discussion_type: Optional[constants.DiscussionTypeFilter] + if kind == DiscussionKind.all: + api_discussion_type = None + else: + api_discussion_type = kind.value # type: ignore[assignment] + + discussions = [] + for d in api.get_repo_discussions( + repo_id=repo_id, + author=author, + discussion_type=api_discussion_type, + discussion_status=api_status, + repo_type=repo_type.value, + ): + if status.value in _CLIENT_SIDE_STATUSES and d.status != status.value: + continue + discussions.append(d) + if len(discussions) >= limit: + break + + items = [api_object_to_dict(d) for d in discussions] + + print_list_output( + items, + format=format, + quiet=quiet, + id_key="num", + headers=["num", "title", "is_pull_request", "status", "author", "created_at"], + row_fn=lambda item: [ + f"#{item['num']}", + _format_cell(item.get("title", ""), max_len=50), + "PR" if item.get("is_pull_request") else "", + _format_status(str(item.get("status", ""))), + str(item.get("author", "")), + _format_cell(item.get("created_at", "")), + ], + alignments={"num": "right"}, + ) + + +@discussions_cli.command( + "info", + examples=[ + "hf discussions info username/my-model 5", + "hf discussions info username/my-model 5 --comments", + "hf discussions info username/my-model 5 --diff", + "hf discussions info username/my-model 5 --format json", + ], +) +def discussion_info( + repo_id: RepoIdArg, + num: DiscussionNumArg, + comments: Annotated[ + bool, + typer.Option( + "--comments", + help="Show all comments.", + ), + ] = False, + diff: Annotated[ + bool, + typer.Option( + "--diff", + help="Show the diff (for pull requests).", + ), + ] = False, + no_color: Annotated[ + bool, + typer.Option( + "--no-color", + help="Disable colored output.", + ), + ] = False, + repo_type: RepoTypeOpt = RepoType.model, + format: Annotated[ + InfoFormat, + typer.Option( + help="Output format (text or json).", + ), + ] = InfoFormat.text, + token: TokenOpt = None, +) -> None: + """Get info about a discussion or pull request.""" + import os + + if no_color: + os.environ["NO_COLOR"] = "1" + + api = get_hf_api(token=token) + details = api.get_discussion_details( + repo_id=repo_id, + discussion_num=num, + repo_type=repo_type.value, + ) + + if format == InfoFormat.json: + result = api_object_to_dict(details) + if not diff: + result.pop("diff", None) + print(json.dumps(result, indent=2)) + return + + _print_discussion_info(details, show_comments=comments) + + if diff and details.diff: + print() + print(ANSI.gray("─" * 60)) + print(details.diff) + + +@discussions_cli.command( + "create", + examples=[ + 'hf discussions create username/my-model --title "Bug report"', + 'hf discussions create username/my-model --title "Feature request" --body "Please add X"', + 'hf discussions create username/my-model --title "Fix typo" --pull-request', + 'hf discussions create username/my-dataset --type dataset --title "Data quality issue"', + ], +) +def discussion_create( + repo_id: RepoIdArg, + title: Annotated[ + str, + typer.Option( + "--title", + help="The title of the discussion or pull request.", + ), + ], + body: Annotated[ + Optional[str], + typer.Option( + "--body", + help="The description (supports Markdown).", + ), + ] = None, + body_file: Annotated[ + Optional[Path], + typer.Option( + "--body-file", + help="Read the description from a file. Use '-' for stdin.", + ), + ] = None, + pull_request: Annotated[ + bool, + typer.Option( + "--pull-request", + "--pr", + help="Create a pull request instead of a discussion.", + ), + ] = False, + repo_type: RepoTypeOpt = RepoType.model, + token: TokenOpt = None, +) -> None: + """Create a new discussion or pull request on a repo.""" + description = _read_body(body, body_file) + api = get_hf_api(token=token) + discussion = api.create_discussion( + repo_id=repo_id, + title=title, + description=description, + repo_type=repo_type.value, + pull_request=pull_request, + ) + kind = "pull request" if pull_request else "discussion" + print(f"Created {kind} {ANSI.bold(f'#{discussion.num}')} on {ANSI.bold(repo_id)}") + if pull_request: + print(f"Push changes to: {ANSI.bold(f'refs/pr/{discussion.num}')}") + print(f"View on Hub: {ANSI.blue(discussion.url)}") + + +@discussions_cli.command( + "comment", + examples=[ + 'hf discussions comment username/my-model 5 --body "Thanks for reporting!"', + 'hf discussions comment username/my-model 5 --body "LGTM!"', + ], +) +def discussion_comment( + repo_id: RepoIdArg, + num: DiscussionNumArg, + body: Annotated[ + Optional[str], + typer.Option( + "--body", + help="The comment text (supports Markdown).", + ), + ] = None, + body_file: Annotated[ + Optional[Path], + typer.Option( + "--body-file", + help="Read the comment from a file. Use '-' for stdin.", + ), + ] = None, + repo_type: RepoTypeOpt = RepoType.model, + token: TokenOpt = None, +) -> None: + """Comment on a discussion or pull request.""" + comment = _read_body(body, body_file) + if comment is None: + raise typer.BadParameter("Either --body or --body-file is required.") + api = get_hf_api(token=token) + api.comment_discussion( + repo_id=repo_id, + discussion_num=num, + comment=comment, + repo_type=repo_type.value, + ) + print(f"Commented on #{num} in {ANSI.bold(repo_id)}") + + +@discussions_cli.command( + "close", + examples=[ + "hf discussions close username/my-model 5", + 'hf discussions close username/my-model 5 --comment "Closing as resolved."', + ], +) +def discussion_close( + repo_id: RepoIdArg, + num: DiscussionNumArg, + comment: Annotated[ + Optional[str], + typer.Option( + "--comment", + help="An optional comment to post when closing.", + ), + ] = None, + yes: Annotated[ + bool, + typer.Option( + "--yes", + "-y", + help="Skip confirmation prompt.", + ), + ] = False, + repo_type: RepoTypeOpt = RepoType.model, + token: TokenOpt = None, +) -> None: + """Close a discussion or pull request.""" + if not yes: + confirm = typer.confirm(f"Close #{num} on '{repo_id}'?") + if not confirm: + print("Aborted.") + raise typer.Exit() + api = get_hf_api(token=token) + api.change_discussion_status( + repo_id=repo_id, + discussion_num=num, + new_status="closed", + comment=comment, + repo_type=repo_type.value, + ) + print(f"Closed #{num} in {ANSI.bold(repo_id)}") + + +@discussions_cli.command( + "reopen", + examples=[ + "hf discussions reopen username/my-model 5", + 'hf discussions reopen username/my-model 5 --comment "Reopening for further investigation."', + ], +) +def discussion_reopen( + repo_id: RepoIdArg, + num: DiscussionNumArg, + comment: Annotated[ + Optional[str], + typer.Option( + "--comment", + help="An optional comment to post when reopening.", + ), + ] = None, + yes: Annotated[ + bool, + typer.Option( + "--yes", + "-y", + help="Skip confirmation prompt.", + ), + ] = False, + repo_type: RepoTypeOpt = RepoType.model, + token: TokenOpt = None, +) -> None: + """Reopen a closed discussion or pull request.""" + if not yes: + confirm = typer.confirm(f"Reopen #{num} on '{repo_id}'?") + if not confirm: + print("Aborted.") + raise typer.Exit() + api = get_hf_api(token=token) + api.change_discussion_status( + repo_id=repo_id, + discussion_num=num, + new_status="open", + comment=comment, + repo_type=repo_type.value, + ) + print(f"Reopened #{num} in {ANSI.bold(repo_id)}") + + +@discussions_cli.command( + "rename", + examples=[ + 'hf discussions rename username/my-model 5 "Updated title"', + ], +) +def discussion_rename( + repo_id: RepoIdArg, + num: DiscussionNumArg, + new_title: Annotated[ + str, + typer.Argument( + help="The new title.", + ), + ], + repo_type: RepoTypeOpt = RepoType.model, + token: TokenOpt = None, +) -> None: + """Rename a discussion or pull request.""" + api = get_hf_api(token=token) + api.rename_discussion( + repo_id=repo_id, + discussion_num=num, + new_title=new_title, + repo_type=repo_type.value, + ) + print(f"Renamed #{num} to {ANSI.bold(new_title)} in {ANSI.bold(repo_id)}") + + +@discussions_cli.command( + "merge", + examples=[ + "hf discussions merge username/my-model 5", + 'hf discussions merge username/my-model 5 --comment "Merging, thanks!"', + ], +) +def discussion_merge( + repo_id: RepoIdArg, + num: DiscussionNumArg, + comment: Annotated[ + Optional[str], + typer.Option( + "--comment", + help="An optional comment to post when merging.", + ), + ] = None, + yes: Annotated[ + bool, + typer.Option( + "--yes", + "-y", + help="Skip confirmation prompt.", + ), + ] = False, + repo_type: RepoTypeOpt = RepoType.model, + token: TokenOpt = None, +) -> None: + """Merge a pull request.""" + if not yes: + confirm = typer.confirm(f"Merge #{num} on '{repo_id}'?") + if not confirm: + print("Aborted.") + raise typer.Exit() + api = get_hf_api(token=token) + api.merge_pull_request( + repo_id=repo_id, + discussion_num=num, + comment=comment, + repo_type=repo_type.value, + ) + print(f"Merged #{num} in {ANSI.bold(repo_id)}") + + +@discussions_cli.command( + "diff", + examples=[ + "hf discussions diff username/my-model 5", + ], +) +def discussion_diff( + repo_id: RepoIdArg, + num: DiscussionNumArg, + repo_type: RepoTypeOpt = RepoType.model, + token: TokenOpt = None, +) -> None: + """Show the diff of a pull request.""" + api = get_hf_api(token=token) + details = api.get_discussion_details( + repo_id=repo_id, + discussion_num=num, + repo_type=repo_type.value, + ) + if details.diff: + print(details.diff) + else: + print("No diff available.") diff --git a/huggingface_hub/cli/download.py b/huggingface_hub/cli/download.py new file mode 100644 index 0000000000000000000000000000000000000000..6b2b2480ead8b613a83c25220dc595064ec2ae4e --- /dev/null +++ b/huggingface_hub/cli/download.py @@ -0,0 +1,225 @@ +# coding=utf-8 +# Copyright 202-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains command to download files from the Hub with the CLI. + +Usage: + hf download --help + + # Download file + hf download gpt2 config.json + + # Download entire repo + hf download fffiloni/zeroscope --repo-type=space --revision=refs/pr/78 + + # Download repo with filters + hf download gpt2 --include="*.safetensors" + + # Download with token + hf download Wauplin/private-model --token=hf_*** + + # Download quietly (no progress bar, no warnings, only the returned path) + hf download gpt2 config.json --quiet + + # Download to local dir + hf download gpt2 --local-dir=./models/gpt2 + + # Download a subfolder + hf download HuggingFaceM4/FineVision art/ --repo-type=dataset +""" + +import warnings +from typing import Annotated, Optional, Union + +import typer + +from huggingface_hub import logging +from huggingface_hub._snapshot_download import snapshot_download +from huggingface_hub.errors import CLIError +from huggingface_hub.file_download import DryRunFileInfo, hf_hub_download +from huggingface_hub.utils import _format_size, disable_progress_bars, enable_progress_bars, tabulate + +from ._cli_utils import RepoIdArg, RepoTypeOpt, RevisionOpt, TokenOpt + + +DOWNLOAD_EXAMPLES = [ + "hf download meta-llama/Llama-3.2-1B-Instruct", + "hf download meta-llama/Llama-3.2-1B-Instruct config.json tokenizer.json", + 'hf download meta-llama/Llama-3.2-1B-Instruct --include "*.safetensors" --exclude "*.bin"', + "hf download meta-llama/Llama-3.2-1B-Instruct --local-dir ./models/llama", + "hf download HuggingFaceM4/FineVision art/ --repo-type dataset", +] + + +logger = logging.get_logger(__name__) + + +def download( + repo_id: RepoIdArg, + filenames: Annotated[ + Optional[list[str]], + typer.Argument( + help="Files to download (e.g. `config.json`, `data/metadata.jsonl`).", + ), + ] = None, + repo_type: RepoTypeOpt = RepoTypeOpt.model, + revision: RevisionOpt = None, + include: Annotated[ + Optional[list[str]], + typer.Option( + help="Glob patterns to include from files to download. eg: *.json", + ), + ] = None, + exclude: Annotated[ + Optional[list[str]], + typer.Option( + help="Glob patterns to exclude from files to download.", + ), + ] = None, + cache_dir: Annotated[ + Optional[str], + typer.Option( + help="Directory where to save files.", + ), + ] = None, + local_dir: Annotated[ + Optional[str], + typer.Option( + help="If set, the downloaded file will be placed under this directory. Check out https://huggingface.co/docs/huggingface_hub/guides/download#download-files-to-a-local-folder for more details.", + ), + ] = None, + force_download: Annotated[ + bool, + typer.Option( + help="If True, the files will be downloaded even if they are already cached.", + ), + ] = False, + dry_run: Annotated[ + bool, + typer.Option( + help="If True, perform a dry run without actually downloading the file.", + ), + ] = False, + token: TokenOpt = None, + quiet: Annotated[ + bool, + typer.Option( + help="If True, progress bars are disabled and only the path to the download files is printed.", + ), + ] = False, + max_workers: Annotated[ + int, + typer.Option( + help="Maximum number of workers to use for downloading files. Default is 8.", + ), + ] = 8, +) -> None: + """Download files from the Hub.""" + + def run_download() -> Union[str, DryRunFileInfo, list[DryRunFileInfo]]: + filenames_list = filenames if filenames is not None else [] + + # Separate subfolder patterns (ending with '/') from regular filenames + # Subfolders like "art/" are converted to include patterns like "art/**" + subfolders = [f for f in filenames_list if f.endswith("/")] + subfolder_patterns = [f"{f.rstrip('/')}/**" for f in subfolders] + regular_filenames = [f for f in filenames_list if not f.endswith("/")] + + # Error if subfolder patterns are combined with --include/--exclude + # Guide user to use --include instead of subfolder argument + if len(subfolder_patterns) > 0: + if include is not None and len(include) > 0: + raise CLIError( + f"Cannot combine subfolder argument ('{subfolders[0]}') with `--include`. " + f'Please use `--include "{subfolders[0]}*"` instead.' + ) + if exclude is not None and len(exclude) > 0: + raise CLIError( + f"Cannot combine subfolder argument ('{subfolders[0]}') with `--exclude`. " + f'Please use `--include "{subfolders[0]}*"` with `--exclude` instead.' + ) + + # Warn user if patterns are ignored (only if regular filenames are provided) + if len(regular_filenames) > 0: + if include is not None and len(include) > 0: + warnings.warn("Ignoring `--include` since filenames have being explicitly set.") + if exclude is not None and len(exclude) > 0: + warnings.warn("Ignoring `--exclude` since filenames have being explicitly set.") + + # Single file to download (not a subfolder): use `hf_hub_download` + if len(regular_filenames) == 1 and len(subfolder_patterns) == 0: + return hf_hub_download( + repo_id=repo_id, + repo_type=repo_type.value, + revision=revision, + filename=regular_filenames[0], + cache_dir=cache_dir, + force_download=force_download, + token=token, + local_dir=local_dir, + library_name="huggingface-cli", + dry_run=dry_run, + ) + + # Otherwise: use `snapshot_download` to ensure all files comes from same revision + if len(regular_filenames) == 0 and len(subfolder_patterns) == 0: + # No filenames provided: use include/exclude patterns + allow_patterns = include + ignore_patterns = exclude + else: + # Combine regular filenames and subfolder patterns as allow_patterns + allow_patterns = regular_filenames + subfolder_patterns + ignore_patterns = None + + return snapshot_download( + repo_id=repo_id, + repo_type=repo_type.value, + revision=revision, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + force_download=force_download, + cache_dir=cache_dir, + token=token, + local_dir=local_dir, + library_name="huggingface-cli", + max_workers=max_workers, + dry_run=dry_run, + ) + + def _print_result(result: Union[str, DryRunFileInfo, list[DryRunFileInfo]]) -> None: + if isinstance(result, str): + print(result) + return + + # Print dry run info + if isinstance(result, DryRunFileInfo): + result = [result] + print( + f"[dry-run] Will download {len([r for r in result if r.will_download])} files (out of {len(result)}) totalling {_format_size(sum(r.file_size for r in result if r.will_download))}." + ) + columns = ["File", "Bytes to download"] + items: list[list[Union[str, int]]] = [] + for info in sorted(result, key=lambda x: x.filename): + items.append([info.filename, _format_size(info.file_size) if info.will_download else "-"]) + print(tabulate(items, headers=columns)) + + if quiet: + disable_progress_bars() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + _print_result(run_download()) + enable_progress_bars() + else: + _print_result(run_download()) + logging.set_verbosity_warning() diff --git a/huggingface_hub/cli/extensions.py b/huggingface_hub/cli/extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..a3da528ea4dac74a529accef3f2a92e0faba7c9b --- /dev/null +++ b/huggingface_hub/cli/extensions.py @@ -0,0 +1,568 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains helper utilities for hf CLI extensions.""" + +import errno +import json +import os +import re +import shutil +import subprocess +import venv +from dataclasses import asdict, dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Annotated, Literal, Optional + +import typer + +from huggingface_hub.errors import CLIError, CLIExtensionInstallError +from huggingface_hub.utils import StatusLine, get_session, logging + +from ._cli_utils import FormatOpt, OutputFormat, QuietOpt, print_list_output, typer_factory + + +DEFAULT_EXTENSION_OWNER = "huggingface" +EXTENSIONS_ROOT = Path("~/.local/share/hf/extensions") +MANIFEST_FILENAME = "manifest.json" +EXTENSIONS_HELP = ( + "Manage hf CLI extensions.\n\n" + "Security Warning: extensions are third-party executables or Python packages. " + "Install only from sources you trust." +) +extensions_cli = typer_factory(help=EXTENSIONS_HELP) +_EXTENSIONS_DEFAULT_BRANCH = "main" # Fallback when the GitHub API is unreachable. +_EXTENSIONS_GITHUB_TOPIC = "hf-extension" +_EXTENSIONS_DOWNLOAD_TIMEOUT = 10 +_EXTENSIONS_PIP_INSTALL_TIMEOUT = 300 + +logger = logging.get_logger(__name__) + + +@dataclass +class ExtensionManifest: + owner: str + repo: str + repo_id: str + short_name: str + executable_name: str + executable_path: str + type: Literal["binary", "python"] + installed_at: datetime + source: str + description: Optional[str] = None + + @classmethod + def load(cls, path: Path) -> "ExtensionManifest": + manifest_path = path / MANIFEST_FILENAME + if not manifest_path.is_file(): + raise CLIError(f"Manifest file not found at {manifest_path}. Your extension may be corrupted.") + data = json.loads(manifest_path.read_text()) + data["installed_at"] = datetime.fromisoformat(data["installed_at"]) + return ExtensionManifest(**data) + + def save(self, path: Path) -> None: + manifest_path = path / MANIFEST_FILENAME + manifest_path.parent.mkdir(parents=True, exist_ok=True) + data = asdict(self) + data["installed_at"] = self.installed_at.isoformat() + manifest_path.write_text(json.dumps(data, indent=2, sort_keys=True)) + + +@extensions_cli.command( + "install", + examples=[ + "hf extensions install hf-claude", + "hf extensions install hanouticelina/hf-claude", + "hf extensions install alvarobartt/hf-mem", + ], +) +def extension_install( + ctx: typer.Context, + repo_id: Annotated[ + str, + typer.Argument(help="GitHub extension repository in `[OWNER/]hf-` format."), + ], + force: Annotated[bool, typer.Option("--force", help="Overwrite if already installed.")] = False, +) -> None: + """Install an extension from a public GitHub repository. + + Security warning: this installs a third-party executable or Python package. + Install only from sources you trust. + """ + owner, repo_name, short_name = _normalize_repo_id(repo_id) + root_ctx = ctx.find_root() + reserved_commands = set(getattr(root_ctx.command, "commands", {}).keys()) + if short_name in reserved_commands: + raise CLIError( + f"Cannot install extension '{short_name}' because it conflicts with an existing `hf {short_name}` command." + ) + + extension_dir = _get_extension_dir(short_name) + extension_exists = extension_dir.exists() + if extension_exists and not force: + raise CLIError(f"Extension '{short_name}' is already installed. Use --force to overwrite.") + + branch, description = _resolve_github_repo_info(owner=owner, repo_name=repo_name) + + if extension_exists: + shutil.rmtree(extension_dir) + + # Check if the repository has a root executable + try: + binary = _fetch_remote_binary(owner=owner, repo_name=repo_name, branch=branch, short_name=short_name) + except Exception: + binary = None + + # Install extension as binary or Python + if binary is not None: + print("Binary found, installing as binary extension...") + manifest = _install_binary_extension( + owner=owner, + repo_name=repo_name, + short_name=short_name, + extension_dir=extension_dir, + binary=binary, + ) + print(f"Binary extension installed successfully from {owner}/{repo_name}.") + else: + print("Binary not found, trying to install as Python extension...") + manifest = _install_python_extension( + owner=owner, + repo_name=repo_name, + short_name=short_name, + extension_dir=extension_dir, + branch=branch, + ) + print(f"Python extension installed successfully from {owner}/{repo_name}.") + + # Try to fetch a description from repo and save + description = _try_fetch_remote_description( + owner=owner, repo_name=repo_name, branch=branch, candidate_description=description + ) + manifest.description = description + manifest.save(extension_dir) + + print(f"Run it with: hf {short_name}") + + +@extensions_cli.command( + "exec", + context_settings={"allow_extra_args": True, "allow_interspersed_args": False, "ignore_unknown_options": True}, + examples=[ + "hf extensions exec claude -- --help", + "hf extensions exec claude --model zai-org/GLM-5", + ], +) +def extension_exec( + ctx: typer.Context, + name: Annotated[ + str, + typer.Argument(help="Extension name (with or without `hf-` prefix)."), + ], +) -> None: + """Execute an installed extension.""" + short_name = _normalize_extension_name(name) + executable_path = _resolve_installed_executable_path(short_name) + + if not executable_path.is_file(): + raise CLIError(f"Extension '{short_name}' is not installed.") + + exit_code = _execute_extension_binary(executable_path=executable_path, args=list(ctx.args)) + raise typer.Exit(code=exit_code) + + +@extensions_cli.command("list | ls", examples=["hf extensions list"]) +def extension_list(format: FormatOpt = OutputFormat.table, quiet: QuietOpt = False) -> None: + """List installed extension commands.""" + rows = [ + { + "command": f"hf {manifest.short_name}", + "source": str(manifest.repo_id), + "type": str(manifest.type), + "installed": manifest.installed_at.strftime("%Y-%m-%d"), + "description": manifest.description, + } + for manifest in _list_installed_extensions() + ] + print_list_output(rows, format=format, quiet=quiet, id_key="command") + + +@extensions_cli.command("search", examples=["hf extensions search"]) +def extension_search(format: FormatOpt = OutputFormat.table, quiet: QuietOpt = False) -> None: + """Search extensions available on GitHub (tagged with 'hf-extension' topic).""" + response = get_session().get( + "https://api.github.com/search/repositories", + params={"q": f"topic:{_EXTENSIONS_GITHUB_TOPIC}", "sort": "stars", "order": "desc", "per_page": 100}, + follow_redirects=True, + timeout=_EXTENSIONS_DOWNLOAD_TIMEOUT, + ) + response.raise_for_status() + data = response.json() + + installed = {m.short_name for m in _list_installed_extensions()} + + rows = [] + for repo in data.get("items", []): + repo_name = repo["name"] + short_name = repo_name[3:] if repo_name.startswith("hf-") else repo_name + rows.append( + { + "name": short_name, + "repo": repo["full_name"], + "stars": repo.get("stargazers_count", 0), + "description": repo.get("description") or "", + "installed": "yes" if short_name in installed else "", + } + ) + + print_list_output(rows, format=format, quiet=quiet, id_key="repo", alignments={"stars": "right"}) + + +@extensions_cli.command("remove | rm", examples=["hf extensions remove claude"]) +def extension_remove( + name: Annotated[ + str, + typer.Argument(help="Extension name to remove (with or without `hf-` prefix)."), + ], +) -> None: + """Remove an installed extension.""" + short_name = _normalize_extension_name(name) + extension_dir = _get_extension_dir(short_name) + + if not extension_dir.is_dir(): + raise CLIError(f"Extension '{short_name}' is not installed.") + + shutil.rmtree(extension_dir) + print(f"Removed extension '{short_name}'.") + + +### HELPER FUNCTIONS + + +def _list_installed_extensions() -> list[ExtensionManifest]: + """Return manifests for all validly-installed extensions, sorted by directory name.""" + root_dir = EXTENSIONS_ROOT.expanduser() + if not root_dir.is_dir(): + return [] + manifests = [] + for extension_dir in sorted(root_dir.iterdir()): + if not extension_dir.is_dir() or not extension_dir.name.startswith("hf-"): + continue + try: + manifests.append(ExtensionManifest.load(extension_dir)) + except Exception as e: + logger.debug(f"Failed to load manifest for extension '{extension_dir.name}': {e}") + continue + return manifests + + +def list_installed_extensions_for_help() -> list[tuple[str, str]]: + entries = [] + for manifest in _list_installed_extensions(): + tag = f"[extension {manifest.repo_id}]" + help_text = f"{manifest.description} {tag}" if manifest.description is not None else tag + entries.append((manifest.short_name, help_text)) + return entries + + +def dispatch_unknown_top_level_extension(args: list[str], known_commands: set[str]) -> Optional[int]: + if not args: + return None + + command_name = args[0] + if command_name.startswith("-"): + return None + if command_name in known_commands: + return None + + short_name = command_name[3:] if command_name.startswith("hf-") else command_name + if not short_name: + return None + + try: + executable_path = _resolve_installed_executable_path(short_name) + except Exception: + return None + + if not executable_path.is_file(): + return None + + return _execute_extension_binary(executable_path=executable_path, args=list(args[1:])) + + +def _fetch_remote_binary(owner: str, repo_name: str, branch: str, short_name: str) -> bytes: + executable_name = _get_executable_name(short_name) + raw_url = f"https://raw.githubusercontent.com/{owner}/{repo_name}/refs/heads/{branch}/{executable_name}" + response = get_session().get(raw_url, follow_redirects=True, timeout=_EXTENSIONS_DOWNLOAD_TIMEOUT) + response.raise_for_status() + return response.content + + +def _install_binary_extension( + *, owner: str, repo_name: str, short_name: str, extension_dir: Path, binary: bytes +) -> ExtensionManifest: + # Save extension binary + executable_name = _get_executable_name(short_name) + extension_dir.mkdir(parents=True, exist_ok=False) + executable_path = extension_dir / executable_name + executable_path.write_bytes(binary) + + # Make it executable + if os.name != "nt": + os.chmod(executable_path, 0o755) + + # Create manifest + return ExtensionManifest( + owner=owner, + repo=repo_name, + repo_id=f"{owner}/{repo_name}", + short_name=short_name, + executable_name=executable_name, + executable_path=str(executable_path), + type="binary", + installed_at=datetime.now(timezone.utc), + source=f"https://github.com/{owner}/{repo_name}", + ) + + +def _install_python_extension( + *, owner: str, repo_name: str, short_name: str, extension_dir: Path, branch: str +) -> ExtensionManifest: + source_url = f"https://github.com/{owner}/{repo_name}/archive/refs/heads/{branch}.zip" + venv_dir = extension_dir / "venv" + installed = False + + status = StatusLine() + try: + status.update(f"Creating virtual environment in {venv_dir}") + if extension_dir.exists(): + shutil.rmtree(extension_dir, ignore_errors=True) + extension_dir.mkdir(parents=True, exist_ok=False) + venv.EnvBuilder(with_pip=True).create(str(venv_dir)) + status.done(f"Virtual environment created in {venv_dir}") + + status.update(f"Installing package from {source_url}") + venv_python = _get_venv_python_path(venv_dir) + subprocess.run( + [ + str(venv_python), + "-m", + "pip", + "install", + "--disable-pip-version-check", + "--no-input", + source_url, + ], + check=True, + timeout=_EXTENSIONS_PIP_INSTALL_TIMEOUT, + ) + status.done(f"Package installed from {source_url}") + + executable_name = _get_executable_name(short_name) + venv_executable = _get_venv_extension_executable_path(venv_dir, short_name) + if not venv_executable.is_file(): + raise CLIError( + f"Installed package from '{owner}/{repo_name}' does not expose the required console script " + f"'{executable_name}'." + ) + + manifest = ExtensionManifest( + owner=owner, + repo=repo_name, + repo_id=f"{owner}/{repo_name}", + short_name=short_name, + executable_name=executable_name, + executable_path=str(venv_executable.resolve()), + type="python", + installed_at=datetime.now(timezone.utc), + source=f"https://github.com/{owner}/{repo_name}", + ) + installed = True + return manifest + except CLIError: + raise + except subprocess.TimeoutExpired as e: + raise CLIExtensionInstallError( + f"Pip install timed out after {_EXTENSIONS_PIP_INSTALL_TIMEOUT}s for '{owner}/{repo_name}'. " + "See pip output above for details." + ) from e + except subprocess.CalledProcessError as e: + raise CLIExtensionInstallError( + f"Failed to install pip package from '{owner}/{repo_name}' (exit code {e.returncode}). " + "See pip output above for details." + ) from e + except Exception as e: + raise CLIExtensionInstallError(f"Failed to set up pip extension from '{owner}/{repo_name}': {e}") from e + finally: + if not installed: + shutil.rmtree(extension_dir, ignore_errors=True) + + +def _try_fetch_remote_description( + owner: str, repo_name: str, branch: str, candidate_description: Optional[str] +) -> Optional[str]: + """Try to fetch project description either from: + - manifest.json + - pyproject.toml + + Only best effort, no error handling. + """ + # from manifest.json + try: + response = get_session().get( + f"https://raw.githubusercontent.com/{owner}/{repo_name}/refs/heads/{branch}/{MANIFEST_FILENAME}", + follow_redirects=True, + ) + response.raise_for_status() + data = response.json() + description = data.get("description") + if isinstance(description, str): + return description + except Exception: + pass + + # from pyproject.toml + try: + response = get_session().get( + f"https://raw.githubusercontent.com/{owner}/{repo_name}/refs/heads/{branch}/pyproject.toml", + follow_redirects=True, + ) + response.raise_for_status() + + # Weak parser but ok for "best effort" + for line in response.text.splitlines(): + line = line.strip() + if line.startswith("description"): + _, _, value = line.partition("=") + return value.strip().strip("\"'") + except Exception: + pass + + # fallback to value fetched from GH API directly + return candidate_description + + +def _get_extensions_root() -> Path: + root_dir = EXTENSIONS_ROOT.expanduser() + root_dir.mkdir(parents=True, exist_ok=True) + return root_dir + + +def _get_extension_dir(short_name: str) -> Path: + safe_name = _validate_extension_short_name(short_name, original_input=short_name) + root = _get_extensions_root().resolve() + target = (root / f"hf-{safe_name}").resolve() + if root not in target.parents: + raise CLIError(f"Invalid extension name '{short_name}'.") + return target + + +def _resolve_github_repo_info(owner: str, repo_name: str) -> tuple[str, Optional[str]]: + try: + response = get_session().get( + f"https://api.github.com/repos/{owner}/{repo_name}", + follow_redirects=True, + timeout=_EXTENSIONS_DOWNLOAD_TIMEOUT, + ) + response.raise_for_status() + data = response.json() + return data["default_branch"], data.get("description") + except Exception: + return _EXTENSIONS_DEFAULT_BRANCH, None + + +def _get_executable_name(short_name: str) -> str: + name = f"hf-{short_name}" + if os.name == "nt": + name += ".exe" + return name + + +def _resolve_installed_executable_path(short_name: str) -> Path: + extension_dir = _get_extension_dir(short_name) + manifest = ExtensionManifest.load(extension_dir) + return Path(manifest.executable_path).expanduser() + + +def _get_venv_python_path(venv_dir: Path) -> Path: + if os.name == "nt": + return venv_dir / "Scripts" / "python.exe" + return venv_dir / "bin" / "python" + + +def _get_venv_extension_executable_path(venv_dir: Path, short_name: str) -> Path: + executable_name = _get_executable_name(short_name) + if os.name == "nt": + return venv_dir / "Scripts" / executable_name + return venv_dir / "bin" / executable_name + + +_ALLOWED_EXTENSION_NAME = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$") + + +def _validate_extension_short_name(short_name: str, *, original_input: str) -> str: + name = short_name.strip() + if not name: + raise CLIError("Extension name cannot be empty.") + if any(sep in name for sep in ("/", "\\")): + raise CLIError(f"Invalid extension name '{original_input}'.") + if ".." in name or ":" in name: + raise CLIError(f"Invalid extension name '{original_input}'.") + if not _ALLOWED_EXTENSION_NAME.fullmatch(name): + raise CLIError( + f"Invalid extension name '{original_input}'. Allowed characters: letters, digits, '.', '_' and '-'." + ) + return name + + +def _normalize_repo_id(repo_id: str) -> tuple[str, str, str]: + if "://" in repo_id: + raise CLIError("Only GitHub repositories in `[OWNER/]hf-` format are supported.") + + parts = repo_id.split("/") + if len(parts) == 1: + owner = DEFAULT_EXTENSION_OWNER + repo_name = parts[0] + elif len(parts) == 2 and all(parts): + owner, repo_name = parts + else: + raise CLIError(f"Expected `[OWNER/]REPO` format, got '{repo_id}'.") + + if not repo_name.startswith("hf-"): + raise CLIError(f"Extension repository name must start with 'hf-', got '{repo_name}'.") + + short_name = repo_name[3:] + if not short_name: + raise CLIError("Invalid extension repository name 'hf-'.") + _validate_extension_short_name(short_name, original_input=repo_id) + + return owner, repo_name, short_name + + +def _normalize_extension_name(name: str) -> str: + candidate = name.strip() + if not candidate: + raise CLIError("Extension name cannot be empty.") + normalized = candidate[3:] if candidate.startswith("hf-") else candidate + return _validate_extension_short_name(normalized, original_input=name) + + +def _execute_extension_binary(executable_path: Path, args: list[str]) -> int: + try: + return subprocess.call([str(executable_path)] + args) + except OSError as e: + if os.name == "nt" or e.errno != errno.ENOEXEC: + raise + return subprocess.call(["sh", str(executable_path)] + args) diff --git a/huggingface_hub/cli/hf.py b/huggingface_hub/cli/hf.py new file mode 100644 index 0000000000000000000000000000000000000000..decb25e92bd12eba10e638fc6be21b7123674e2b --- /dev/null +++ b/huggingface_hub/cli/hf.py @@ -0,0 +1,127 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import traceback +from typing import Annotated, Optional + +import typer + +from huggingface_hub import __version__, constants +from huggingface_hub.cli._cli_utils import check_cli_update, fallback_typer_group_factory, typer_factory +from huggingface_hub.cli._errors import format_known_exception +from huggingface_hub.cli.auth import auth_cli +from huggingface_hub.cli.buckets import buckets_cli, sync +from huggingface_hub.cli.cache import cache_cli +from huggingface_hub.cli.collections import collections_cli +from huggingface_hub.cli.datasets import datasets_cli +from huggingface_hub.cli.discussions import discussions_cli +from huggingface_hub.cli.download import DOWNLOAD_EXAMPLES, download +from huggingface_hub.cli.extensions import ( + dispatch_unknown_top_level_extension, + extensions_cli, + list_installed_extensions_for_help, +) +from huggingface_hub.cli.inference_endpoints import ie_cli +from huggingface_hub.cli.jobs import jobs_cli +from huggingface_hub.cli.lfs import lfs_enable_largefiles, lfs_multipart_upload +from huggingface_hub.cli.models import models_cli +from huggingface_hub.cli.papers import papers_cli +from huggingface_hub.cli.repo_files import repo_files_cli +from huggingface_hub.cli.repos import repos_cli +from huggingface_hub.cli.skills import skills_cli +from huggingface_hub.cli.spaces import spaces_cli +from huggingface_hub.cli.system import env, version +from huggingface_hub.cli.upload import UPLOAD_EXAMPLES, upload +from huggingface_hub.cli.upload_large_folder import UPLOAD_LARGE_FOLDER_EXAMPLES, upload_large_folder +from huggingface_hub.cli.webhooks import webhooks_cli +from huggingface_hub.utils import ANSI, logging + + +app = typer_factory( + help="Hugging Face Hub CLI", + cls=fallback_typer_group_factory( + dispatch_unknown_top_level_extension, + extra_commands_provider=list_installed_extensions_for_help, + ), +) + + +def _version_callback(value: bool) -> None: + if value: + print(__version__) + raise typer.Exit() + + +@app.callback(invoke_without_command=True) +def app_callback( + version: Annotated[ + Optional[bool], typer.Option("--version", callback=_version_callback, is_eager=True, hidden=True) + ] = None, +) -> None: + pass + + +# top level single commands (defined in their respective files) +app.command()(sync) +app.command(examples=DOWNLOAD_EXAMPLES)(download) +app.command(examples=UPLOAD_EXAMPLES)(upload) +app.command(examples=UPLOAD_LARGE_FOLDER_EXAMPLES)(upload_large_folder) + +app.command(topic="help")(env) +app.command(topic="help")(version) + +app.command(hidden=True)(lfs_enable_largefiles) +app.command(hidden=True)(lfs_multipart_upload) + +# command groups +app.add_typer(auth_cli, name="auth") +app.add_typer(buckets_cli, name="buckets") +app.add_typer(cache_cli, name="cache") +app.add_typer(collections_cli, name="collections") +app.add_typer(datasets_cli, name="datasets") +app.add_typer(discussions_cli, name="discussions") +app.add_typer(jobs_cli, name="jobs") +app.add_typer(models_cli, name="models") +app.add_typer(papers_cli, name="papers") +app.add_typer(repos_cli, name="repos | repo") +app.add_typer(repo_files_cli, name="repo-files", hidden=True) +app.add_typer(skills_cli, name="skills") +app.add_typer(spaces_cli, name="spaces") +app.add_typer(webhooks_cli, name="webhooks") +app.add_typer(ie_cli, name="endpoints") +app.add_typer(extensions_cli, name="extensions | ext") + + +def main(): + if not constants.HF_DEBUG: + logging.set_verbosity_info() + check_cli_update("huggingface_hub") + + try: + app() + except Exception as e: + message = format_known_exception(e) + if message: + print(f"Error: {message}", file=sys.stderr) + if constants.HF_DEBUG: + traceback.print_exc() + else: + print(ANSI.gray("Set HF_DEBUG=1 as environment variable for full traceback.")) + sys.exit(1) + raise + + +if __name__ == "__main__": + main() diff --git a/huggingface_hub/cli/inference_endpoints.py b/huggingface_hub/cli/inference_endpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..d2b4ddf34331a2da5c87edf4f03a13006100f41d --- /dev/null +++ b/huggingface_hub/cli/inference_endpoints.py @@ -0,0 +1,463 @@ +"""CLI commands for Hugging Face Inference Endpoints.""" + +import json +from typing import Annotated, Any, Optional + +import typer + +from huggingface_hub._inference_endpoints import InferenceEndpoint, InferenceEndpointScalingMetric +from huggingface_hub.errors import HfHubHTTPError + +from ._cli_utils import ( + FormatOpt, + OutputFormat, + QuietOpt, + TokenOpt, + get_hf_api, + print_list_output, + typer_factory, +) + + +ie_cli = typer_factory(help="Manage Hugging Face Inference Endpoints.") + +catalog_app = typer_factory(help="Interact with the Inference Endpoints catalog.") + + +NameArg = Annotated[ + str, + typer.Argument(help="Endpoint name."), +] +NameOpt = Annotated[ + Optional[str], + typer.Option(help="Endpoint name."), +] + +NamespaceOpt = Annotated[ + Optional[str], + typer.Option( + help="The namespace associated with the Inference Endpoint. Defaults to the current user's namespace.", + ), +] + + +def _print_endpoint(endpoint: InferenceEndpoint) -> None: + typer.echo(json.dumps(endpoint.raw, indent=2, sort_keys=True)) + + +@ie_cli.command("list | ls", examples=["hf endpoints ls", "hf endpoints ls --namespace my-org"]) +def ls( + namespace: NamespaceOpt = None, + format: FormatOpt = OutputFormat.table, + quiet: QuietOpt = False, + token: TokenOpt = None, +) -> None: + """Lists all Inference Endpoints for the given namespace.""" + api = get_hf_api(token=token) + try: + endpoints = api.list_inference_endpoints(namespace=namespace, token=token) + except HfHubHTTPError as error: + typer.echo(f"Listing failed: {error}") + raise typer.Exit(code=error.response.status_code) from error + + results = [endpoint.raw for endpoint in endpoints] + + def row_fn(item: dict[str, Any]) -> list[str]: + status = item.get("status", {}) + model = item.get("model", {}) + compute = item.get("compute", {}) + provider = item.get("provider", {}) + return [ + str(item.get("name", "")), + str(model.get("repository", "") if isinstance(model, dict) else ""), + str(status.get("state", "") if isinstance(status, dict) else ""), + str(model.get("task", "") if isinstance(model, dict) else ""), + str(model.get("framework", "") if isinstance(model, dict) else ""), + str(compute.get("instanceType", "") if isinstance(compute, dict) else ""), + str(provider.get("vendor", "") if isinstance(provider, dict) else ""), + str(provider.get("region", "") if isinstance(provider, dict) else ""), + ] + + print_list_output( + items=results, + format=format, + quiet=quiet, + id_key="name", + headers=["NAME", "MODEL", "STATUS", "TASK", "FRAMEWORK", "INSTANCE", "VENDOR", "REGION"], + row_fn=row_fn, + ) + + +@ie_cli.command(name="deploy", examples=["hf endpoints deploy my-endpoint --repo gpt2 --framework pytorch ..."]) +def deploy( + name: NameArg, + repo: Annotated[ + str, + typer.Option( + help="The name of the model repository associated with the Inference Endpoint (e.g. 'openai/gpt-oss-120b').", + ), + ], + framework: Annotated[ + str, + typer.Option( + help="The machine learning framework used for the model (e.g. 'vllm').", + ), + ], + accelerator: Annotated[ + str, + typer.Option( + help="The hardware accelerator to be used for inference (e.g. 'cpu').", + ), + ], + instance_size: Annotated[ + str, + typer.Option( + help="The size or type of the instance to be used for hosting the model (e.g. 'x4').", + ), + ], + instance_type: Annotated[ + str, + typer.Option( + help="The cloud instance type where the Inference Endpoint will be deployed (e.g. 'intel-icl').", + ), + ], + region: Annotated[ + str, + typer.Option( + help="The cloud region in which the Inference Endpoint will be created (e.g. 'us-east-1').", + ), + ], + vendor: Annotated[ + str, + typer.Option( + help="The cloud provider or vendor where the Inference Endpoint will be hosted (e.g. 'aws').", + ), + ], + *, + namespace: NamespaceOpt = None, + task: Annotated[ + Optional[str], + typer.Option( + help="The task on which to deploy the model (e.g. 'text-classification').", + ), + ] = None, + token: TokenOpt = None, + min_replica: Annotated[ + int, + typer.Option( + help="The minimum number of replicas (instances) to keep running for the Inference Endpoint.", + ), + ] = 1, + max_replica: Annotated[ + int, + typer.Option( + help="The maximum number of replicas (instances) to scale to for the Inference Endpoint.", + ), + ] = 1, + scale_to_zero_timeout: Annotated[ + Optional[int], + typer.Option( + help="The duration in minutes before an inactive endpoint is scaled to zero.", + ), + ] = None, + scaling_metric: Annotated[ + Optional[InferenceEndpointScalingMetric], + typer.Option( + help="The metric reference for scaling.", + ), + ] = None, + scaling_threshold: Annotated[ + Optional[float], + typer.Option( + help="The scaling metric threshold used to trigger a scale up. Ignored when scaling metric is not provided.", + ), + ] = None, +) -> None: + """Deploy an Inference Endpoint from a Hub repository.""" + api = get_hf_api(token=token) + endpoint = api.create_inference_endpoint( + name=name, + repository=repo, + framework=framework, + accelerator=accelerator, + instance_size=instance_size, + instance_type=instance_type, + region=region, + vendor=vendor, + namespace=namespace, + task=task, + token=token, + min_replica=min_replica, + max_replica=max_replica, + scaling_metric=scaling_metric, + scaling_threshold=scaling_threshold, + scale_to_zero_timeout=scale_to_zero_timeout, + ) + + _print_endpoint(endpoint) + + +@catalog_app.command(name="deploy", examples=["hf endpoints catalog deploy --repo meta-llama/Llama-3.2-1B-Instruct"]) +def deploy_from_catalog( + repo: Annotated[ + str, + typer.Option( + help="The name of the model repository associated with the Inference Endpoint (e.g. 'openai/gpt-oss-120b').", + ), + ], + name: NameOpt = None, + accelerator: Annotated[ + Optional[str], + typer.Option( + help="The hardware accelerator to be used for inference (e.g. 'cpu', 'gpu', 'neuron').", + ), + ] = None, + namespace: NamespaceOpt = None, + token: TokenOpt = None, +) -> None: + """Deploy an Inference Endpoint from the Model Catalog.""" + api = get_hf_api(token=token) + try: + endpoint = api.create_inference_endpoint_from_catalog( + repo_id=repo, + name=name, + accelerator=accelerator, + namespace=namespace, + token=token, + ) + except HfHubHTTPError as error: + typer.echo(f"Deployment failed: {error}") + raise typer.Exit(code=error.response.status_code) from error + + _print_endpoint(endpoint) + + +def list_catalog( + token: TokenOpt = None, +) -> None: + """List available Catalog models.""" + api = get_hf_api(token=token) + try: + models = api.list_inference_catalog(token=token) + except HfHubHTTPError as error: + typer.echo(f"Catalog fetch failed: {error}") + raise typer.Exit(code=error.response.status_code) from error + + typer.echo(json.dumps({"models": models}, indent=2, sort_keys=True)) + + +catalog_app.command(name="list | ls", examples=["hf endpoints catalog ls"])(list_catalog) +ie_cli.command(name="list-catalog", hidden=True)(list_catalog) + + +ie_cli.add_typer(catalog_app, name="catalog") + + +@ie_cli.command(examples=["hf endpoints describe my-endpoint"]) +def describe( + name: NameArg, + namespace: NamespaceOpt = None, + token: TokenOpt = None, +) -> None: + """Get information about an existing endpoint.""" + api = get_hf_api(token=token) + try: + endpoint = api.get_inference_endpoint(name=name, namespace=namespace, token=token) + except HfHubHTTPError as error: + typer.echo(f"Fetch failed: {error}") + raise typer.Exit(code=error.response.status_code) from error + + _print_endpoint(endpoint) + + +@ie_cli.command(examples=["hf endpoints update my-endpoint --min-replica 2"]) +def update( + name: NameArg, + namespace: NamespaceOpt = None, + repo: Annotated[ + Optional[str], + typer.Option( + help="The name of the model repository associated with the Inference Endpoint (e.g. 'openai/gpt-oss-120b').", + ), + ] = None, + accelerator: Annotated[ + Optional[str], + typer.Option( + help="The hardware accelerator to be used for inference (e.g. 'cpu').", + ), + ] = None, + instance_size: Annotated[ + Optional[str], + typer.Option( + help="The size or type of the instance to be used for hosting the model (e.g. 'x4').", + ), + ] = None, + instance_type: Annotated[ + Optional[str], + typer.Option( + help="The cloud instance type where the Inference Endpoint will be deployed (e.g. 'intel-icl').", + ), + ] = None, + framework: Annotated[ + Optional[str], + typer.Option( + help="The machine learning framework used for the model (e.g. 'custom').", + ), + ] = None, + revision: Annotated[ + Optional[str], + typer.Option( + help="The specific model revision to deploy on the Inference Endpoint (e.g. '6c0e6080953db56375760c0471a8c5f2929baf11').", + ), + ] = None, + task: Annotated[ + Optional[str], + typer.Option( + help="The task on which to deploy the model (e.g. 'text-classification').", + ), + ] = None, + min_replica: Annotated[ + Optional[int], + typer.Option( + help="The minimum number of replicas (instances) to keep running for the Inference Endpoint.", + ), + ] = None, + max_replica: Annotated[ + Optional[int], + typer.Option( + help="The maximum number of replicas (instances) to scale to for the Inference Endpoint.", + ), + ] = None, + scale_to_zero_timeout: Annotated[ + Optional[int], + typer.Option( + help="The duration in minutes before an inactive endpoint is scaled to zero.", + ), + ] = None, + scaling_metric: Annotated[ + Optional[InferenceEndpointScalingMetric], + typer.Option( + help="The metric reference for scaling.", + ), + ] = None, + scaling_threshold: Annotated[ + Optional[float], + typer.Option( + help="The scaling metric threshold used to trigger a scale up. Ignored when scaling metric is not provided.", + ), + ] = None, + token: TokenOpt = None, +) -> None: + """Update an existing endpoint.""" + api = get_hf_api(token=token) + try: + endpoint = api.update_inference_endpoint( + name=name, + namespace=namespace, + repository=repo, + framework=framework, + revision=revision, + task=task, + accelerator=accelerator, + instance_size=instance_size, + instance_type=instance_type, + min_replica=min_replica, + max_replica=max_replica, + scale_to_zero_timeout=scale_to_zero_timeout, + scaling_metric=scaling_metric, + scaling_threshold=scaling_threshold, + token=token, + ) + except HfHubHTTPError as error: + typer.echo(f"Update failed: {error}") + raise typer.Exit(code=error.response.status_code) from error + _print_endpoint(endpoint) + + +@ie_cli.command(examples=["hf endpoints delete my-endpoint"]) +def delete( + name: NameArg, + namespace: NamespaceOpt = None, + yes: Annotated[ + bool, + typer.Option("--yes", help="Skip confirmation prompts."), + ] = False, + token: TokenOpt = None, +) -> None: + """Delete an Inference Endpoint permanently.""" + if not yes: + confirmation = typer.prompt(f"Delete endpoint '{name}'? Type the name to confirm.") + if confirmation != name: + typer.echo("Aborted.") + raise typer.Exit(code=2) + + api = get_hf_api(token=token) + try: + api.delete_inference_endpoint(name=name, namespace=namespace, token=token) + except HfHubHTTPError as error: + typer.echo(f"Delete failed: {error}") + raise typer.Exit(code=error.response.status_code) from error + + typer.echo(f"Deleted '{name}'.") + + +@ie_cli.command(examples=["hf endpoints pause my-endpoint"]) +def pause( + name: NameArg, + namespace: NamespaceOpt = None, + token: TokenOpt = None, +) -> None: + """Pause an Inference Endpoint.""" + api = get_hf_api(token=token) + try: + endpoint = api.pause_inference_endpoint(name=name, namespace=namespace, token=token) + except HfHubHTTPError as error: + typer.echo(f"Pause failed: {error}") + raise typer.Exit(code=error.response.status_code) from error + + _print_endpoint(endpoint) + + +@ie_cli.command(examples=["hf endpoints resume my-endpoint"]) +def resume( + name: NameArg, + namespace: NamespaceOpt = None, + fail_if_already_running: Annotated[ + bool, + typer.Option( + "--fail-if-already-running", + help="If `True`, the method will raise an error if the Inference Endpoint is already running.", + ), + ] = False, + token: TokenOpt = None, +) -> None: + """Resume an Inference Endpoint.""" + api = get_hf_api(token=token) + try: + endpoint = api.resume_inference_endpoint( + name=name, + namespace=namespace, + token=token, + running_ok=not fail_if_already_running, + ) + except HfHubHTTPError as error: + typer.echo(f"Resume failed: {error}") + raise typer.Exit(code=error.response.status_code) from error + _print_endpoint(endpoint) + + +@ie_cli.command(examples=["hf endpoints scale-to-zero my-endpoint"]) +def scale_to_zero( + name: NameArg, + namespace: NamespaceOpt = None, + token: TokenOpt = None, +) -> None: + """Scale an Inference Endpoint to zero.""" + api = get_hf_api(token=token) + try: + endpoint = api.scale_to_zero_inference_endpoint(name=name, namespace=namespace, token=token) + except HfHubHTTPError as error: + typer.echo(f"Scale To Zero failed: {error}") + raise typer.Exit(code=error.response.status_code) from error + + _print_endpoint(endpoint) diff --git a/huggingface_hub/cli/jobs.py b/huggingface_hub/cli/jobs.py new file mode 100644 index 0000000000000000000000000000000000000000..d965d8d71ec41a7e9010b4a200c449222c118f86 --- /dev/null +++ b/huggingface_hub/cli/jobs.py @@ -0,0 +1,1213 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains commands to interact with jobs on the Hugging Face Hub. + +Usage: + # run a job + hf jobs run + + # List running or completed jobs + hf jobs ps [-a] [-f key=value] [--format table|json|TEMPLATE] [-q] + + # Print logs from a job (non-blocking) + hf jobs logs + + # Stream logs from a job (blocking, like `docker logs -f`) + hf jobs logs -f + + # Stream resources usage stats and metrics from a job + hf jobs stats + + # Inspect detailed information about a job + hf jobs inspect + + # Cancel a running job + hf jobs cancel + + # List available hardware options + hf jobs hardware + + # Run a UV script + hf jobs uv run )', bygroups(using(XmlLexer), + Other, + using(XmlLexer))), + (r'(.+?)(?=<)', using(XmlLexer)), + (r'.+', using(XmlLexer)), + ], + } + + +# TODO support multiple languages within the same source file +class CSharpAspxLexer(DelegatingLexer): + """ + Lexer for highlighting C# within ASP.NET pages. + """ + + name = 'aspx-cs' + aliases = ['aspx-cs'] + filenames = ['*.aspx', '*.asax', '*.ascx', '*.ashx', '*.asmx', '*.axd'] + mimetypes = [] + url = 'https://dotnet.microsoft.com/en-us/apps/aspnet' + version_added = '' + + def __init__(self, **options): + super().__init__(CSharpLexer, GenericAspxLexer, **options) + + def analyse_text(text): + if re.search(r'Page\s*Language="C#"', text, re.I) is not None: + return 0.2 + elif re.search(r'script[^>]+language=["\']C#', text, re.I) is not None: + return 0.15 + + +class VbNetAspxLexer(DelegatingLexer): + """ + Lexer for highlighting Visual Basic.net within ASP.NET pages. + """ + + name = 'aspx-vb' + aliases = ['aspx-vb'] + filenames = ['*.aspx', '*.asax', '*.ascx', '*.ashx', '*.asmx', '*.axd'] + mimetypes = [] + url = 'https://dotnet.microsoft.com/en-us/apps/aspnet' + version_added = '' + + def __init__(self, **options): + super().__init__(VbNetLexer, GenericAspxLexer, **options) + + def analyse_text(text): + if re.search(r'Page\s*Language="Vb"', text, re.I) is not None: + return 0.2 + elif re.search(r'script[^>]+language=["\']vb', text, re.I) is not None: + return 0.15 + + +# Very close to functional.OcamlLexer +class FSharpLexer(RegexLexer): + """ + For the F# language (version 3.0). + """ + + name = 'F#' + url = 'https://fsharp.org/' + aliases = ['fsharp', 'f#'] + filenames = ['*.fs', '*.fsi', '*.fsx'] + mimetypes = ['text/x-fsharp'] + version_added = '1.5' + + keywords = [ + 'abstract', 'as', 'assert', 'base', 'begin', 'class', 'default', + 'delegate', 'do!', 'do', 'done', 'downcast', 'downto', 'elif', 'else', + 'end', 'exception', 'extern', 'false', 'finally', 'for', 'function', + 'fun', 'global', 'if', 'inherit', 'inline', 'interface', 'internal', + 'in', 'lazy', 'let!', 'let', 'match', 'member', 'module', 'mutable', + 'namespace', 'new', 'null', 'of', 'open', 'override', 'private', 'public', + 'rec', 'return!', 'return', 'select', 'static', 'struct', 'then', 'to', + 'true', 'try', 'type', 'upcast', 'use!', 'use', 'val', 'void', 'when', + 'while', 'with', 'yield!', 'yield', + ] + # Reserved words; cannot hurt to color them as keywords too. + keywords += [ + 'atomic', 'break', 'checked', 'component', 'const', 'constraint', + 'constructor', 'continue', 'eager', 'event', 'external', 'fixed', + 'functor', 'include', 'method', 'mixin', 'object', 'parallel', + 'process', 'protected', 'pure', 'sealed', 'tailcall', 'trait', + 'virtual', 'volatile', + ] + keyopts = [ + '!=', '#', '&&', '&', r'\(', r'\)', r'\*', r'\+', ',', r'-\.', + '->', '-', r'\.\.', r'\.', '::', ':=', ':>', ':', ';;', ';', '<-', + r'<\]', '<', r'>\]', '>', r'\?\?', r'\?', r'\[<', r'\[\|', r'\[', r'\]', + '_', '`', r'\{', r'\|\]', r'\|', r'\}', '~', '<@@', '<@', '=', '@>', '@@>', + ] + + operators = r'[!$%&*+\./:<=>?@^|~-]' + word_operators = ['and', 'or', 'not'] + prefix_syms = r'[!?~]' + infix_syms = r'[=<>@^|&+\*/$%-]' + primitives = [ + 'sbyte', 'byte', 'char', 'nativeint', 'unativeint', 'float32', 'single', + 'float', 'double', 'int8', 'uint8', 'int16', 'uint16', 'int32', + 'uint32', 'int64', 'uint64', 'decimal', 'unit', 'bool', 'string', + 'list', 'exn', 'obj', 'enum', + ] + + # See http://msdn.microsoft.com/en-us/library/dd233181.aspx and/or + # http://fsharp.org/about/files/spec.pdf for reference. Good luck. + + tokens = { + 'escape-sequence': [ + (r'\\[\\"\'ntbrafv]', String.Escape), + (r'\\[0-9]{3}', String.Escape), + (r'\\u[0-9a-fA-F]{4}', String.Escape), + (r'\\U[0-9a-fA-F]{8}', String.Escape), + ], + 'root': [ + (r'\s+', Whitespace), + (r'\(\)|\[\]', Name.Builtin.Pseudo), + (r'\b(? and <| are weak + indicators.""" + result = 0 + if '|>' in text: + result += 0.05 + if '<|' in text: + result += 0.05 + + return result + + +class XppLexer(RegexLexer): + + """ + For X++ source code. This is based loosely on the CSharpLexer + """ + + name = 'X++' + url = 'https://learn.microsoft.com/en-us/dynamics365/fin-ops-core/dev-itpro/dev-ref/xpp-language-reference' + aliases = ['xpp', 'x++'] + filenames = ['*.xpp'] + version_added = '2.15' + + flags = re.MULTILINE + + XPP_CHARS = ('@?(?:_|[^' + + uni.allexcept('Lu', 'Ll', 'Lt', 'Lm', 'Lo', 'Nl') + '])' + + '[^' + uni.allexcept('Lu', 'Ll', 'Lt', 'Lm', 'Lo', 'Nl', + 'Nd', 'Pc', 'Cf', 'Mn', 'Mc') + ']*') + # Temporary, see + # https://github.com/thatch/regexlint/pull/49 + XPP_CHARS = XPP_CHARS.replace('\x00', '\x01') + + OPERATORS = ( + '<=', '>=', '+=', '-=', '*=', '/=', '!=', '==', + '&&', '||', '>>', '<<', '++', '--', '+', '-', '*', + '/', '%', '&', '|', '^', '<', '>', '?', '!', '~', '=', + ) + KEYWORDS = ('abstract','anytype','as','async','asc','at','avg','break','breakpoint','by','byref','case','catch', + 'changecompany','client','container','continue','count','crosscompany','default','delegate', + 'delete_from','desc','display','div','do','edit','else','element','eventhandler','exists','false','final', + 'firstfast','firstonly','firstonly10','firstonly100','firstonly1000','flush','for','forceliterals', + 'forcenestedloop','forceplaceholders','forceselectorder','forupdate','from','group','if','insert_recordset', + 'interface','is','join','like','maxof','minof','mod','new','next','nofetch','notexists','null','optimisticlock','order', + 'outer','pause','pessimisticlock','print','private','protected','public','repeatableread','retry','return', + 'reverse','select','server','setting','static','sum','super','switch','tablelock','this','throw','true','try','ttsabort','ttsbegin', + 'ttscommit','update_recordset','validtimestate','void','where','while','window') + RUNTIME_FUNCTIONS = ('_duration','abs','acos','any2Date','any2Enum','any2Guid','any2Int','any2Int64','any2Real','any2Str','anytodate', + 'anytoenum','anytoguid','anytoint','anytoint64','anytoreal','anytostr','asin','atan','beep','cTerm','char2Num','classIdGet', + 'corrFlagGet','corrFlagSet','cos','cosh','curExt','curUserId','date2Num','date2Str','datetime2Str','dayName','dayOfMth', + 'dayOfWk','dayOfYr','ddb','decRound','dg','dimOf','endMth','enum2str','exp','exp10','fV','fieldId2Name','fieldId2PName', + 'fieldName2Id','frac','funcName','getCurrentPartition','getCurrentPartitionRecId','getPrefix','guid2Str','idg','indexId2Name', + 'indexName2Id','int2Str','int642Str','intvMax','intvName','intvNo','intvNorm','log10','logN','match','max','min','mkDate','mthName', + 'mthOfYr','newGuid','nextMth','nextQtr','nextYr','num2Char','num2Date','num2Str','pmt','power','prevMth','prevQtr','prevYr', + 'prmIsDefault','pt','pv','rate','refPrintAll','round','runAs','sessionId','setPrefix','sin','sinh','sleep','sln','str2Date', + 'str2Datetime','str2Enum','str2Guid','str2Int','str2Int64','str2Num','str2Time','strAlpha','strCmp','strColSeq','strDel', + 'strFind','strFmt','strIns','strKeep','strLTrim','strLen','strLine','strLwr','strNFind','strPoke','strPrompt','strRTrim', + 'strRem','strRep','strScan','strUpr','subStr','syd','systemDateGet','systemDateSet','tableId2Name', + 'tableId2PName','tableName2Id','tan','tanh','term','time2Str','timeNow','today','trunc','typeOf','uint2Str','wkOfYr','year') + COMPILE_FUNCTIONS = ('attributeStr','classNum','classStr','configurationKeyNum','configurationKeyStr','dataEntityDataSourceStr','delegateStr', + 'dimensionHierarchyLevelStr','dimensionHierarchyStr','dimensionReferenceStr','dutyStr','enumCnt','enumLiteralStr','enumNum','enumStr', + 'extendedTypeNum','extendedTypeStr','fieldNum','fieldPName','fieldStr','formControlStr','formDataFieldStr','formDataSourceStr', + 'formMethodStr','formStr','identifierStr','indexNum','indexStr','licenseCodeNum','licenseCodeStr','literalStr','maxDate','maxInt', + 'measureStr','measurementStr','menuItemActionStr','menuItemDisplayStr','menuItemOutputStr','menuStr','methodStr','minInt','privilegeStr', + 'queryDatasourceStr','queryMethodStr','queryStr','reportStr','resourceStr','roleStr','ssrsReportStr','staticDelegateStr','staticMethodStr', + 'tableCollectionStr','tableFieldGroupStr','tableMethodStr','tableNum','tablePName','tableStaticMethodStr','tableStr','tileStr','varStr', + 'webActionItemStr','webDisplayContentItemStr','webFormStr','webMenuStr','webOutputContentItemStr','webReportStr','webSiteTempStr', + 'webStaticFileStr','webUrlItemStr','webWebPartStr','webletItemStr','webpageDefStr','websiteDefStr','workflowApprovalStr', + 'workflowCategoryStr','workflowTaskStr','workflowTypeStr') + + tokens = {} + + tokens = { + 'root': [ + # method names + (r'(\s*)\b(else|if)\b([^\n])', bygroups(Whitespace, Keyword, using(this))), # ensure that if is not treated like a function + (r'^([ \t]*)((?:' + XPP_CHARS + r'(?:\[\])?\s+)+?)' # return type + r'(' + XPP_CHARS + ')' # method name + r'(\s*)(\()', # signature start + bygroups(Whitespace, using(this), Name.Function, Whitespace, + Punctuation)), + (r'^(\s*)(\[)([^\n]*?)(\])', bygroups(Whitespace, Name.Attribute, Name.Variable.Class, Name.Attribute)), + (r'[^\S\n]+', Whitespace), + (r'(\\)(\n)', bygroups(Text, Whitespace)), # line continuation + (r'//[^\n]*?\n', Comment.Single), + (r'/[*][^\n]*?[*]/', Comment.Multiline), + (r'\n', Whitespace), + (words(OPERATORS), Operator), + (r'=~|!=|==|<<|>>|[-+/*%=<>&^|]', Operator), + (r'[()\[\];:,.#@]', Punctuation), + (r'[{}]', Punctuation), + (r'@"(""|[^"])*"', String), + (r'\$?"(\\\\|\\[^\\]|[^"\\\n])*["\n]', String), + (r"'\\.'|'[^\\]'", String.Char), + (r"[0-9]+(\.[0-9]*)?([eE][+-][0-9]+)?" + r"[flFLdD]?|0[xX][0-9a-fA-F]+[Ll]?", Number), + (words(KEYWORDS, suffix=r'\b'), Keyword), + (r'(boolean|int|int64|str|real|guid|date)\b\??', Keyword.Type), + (r'(class|struct|extends|implements)(\s+)', bygroups(Keyword, Whitespace), 'class'), + (r'('+XPP_CHARS+')(::)', bygroups(Name.Variable.Class, Punctuation)), + (r'(\s*)(\w+)(\s+\w+(,|=)?[^\n]*;)', bygroups(Whitespace, Name.Variable.Class, using(this))), # declaration + # x++ specific function to get field should highlight the classname + (r'(fieldNum\()('+XPP_CHARS+r')(\s*,\s*)('+XPP_CHARS+r')(\s*\))', + bygroups(using(this), Name.Variable.Class, using(this), Name.Property, using(this))), + # x++ specific function to get table should highlight the classname + (r'(tableNum\()('+XPP_CHARS+r')(\s*\))', + bygroups(using(this), Name.Variable.Class, using(this))), + (words(RUNTIME_FUNCTIONS, suffix=r'(?=\()'), Name.Function.Magic), + (words(COMPILE_FUNCTIONS, suffix=r'(?=\()'), Name.Function.Magic), + (XPP_CHARS, Name), + ], + 'class': [ + (XPP_CHARS, Name.Class, '#pop'), + default('#pop'), + ], + 'namespace': [ + (r'(?=\()', Text, '#pop'), # using (resource) + ('(' + XPP_CHARS + r'|\.)+', Name.Namespace, '#pop'), + ] + } diff --git a/pygments/lexers/dsls.py b/pygments/lexers/dsls.py new file mode 100644 index 0000000000000000000000000000000000000000..d30c11249ef4da74a6da56e1c5122dc7c58fa707 --- /dev/null +++ b/pygments/lexers/dsls.py @@ -0,0 +1,970 @@ +""" + pygments.lexers.dsls + ~~~~~~~~~~~~~~~~~~~~ + + Lexers for various domain-specific languages. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +import re + +from pygments.lexer import ExtendedRegexLexer, RegexLexer, bygroups, words, \ + include, default, this, using, combined +from pygments.token import Text, Comment, Operator, Keyword, Name, String, \ + Number, Punctuation, Whitespace + +__all__ = ['ProtoBufLexer', 'ZeekLexer', 'PuppetLexer', 'RslLexer', + 'MscgenLexer', 'VGLLexer', 'AlloyLexer', 'PanLexer', + 'CrmshLexer', 'ThriftLexer', 'FlatlineLexer', 'SnowballLexer'] + + +class ProtoBufLexer(RegexLexer): + """ + Lexer for Protocol Buffer definition files. + """ + + name = 'Protocol Buffer' + url = 'https://developers.google.com/protocol-buffers/' + aliases = ['protobuf', 'proto'] + filenames = ['*.proto'] + version_added = '1.4' + + tokens = { + 'root': [ + (r'[ \t]+', Whitespace), + (r'[,;{}\[\]()<>]', Punctuation), + (r'/(\\\n)?/(\n|(.|\n)*?[^\\]\n)', Comment.Single), + (r'/(\\\n)?\*(.|\n)*?\*(\\\n)?/', Comment.Multiline), + (words(( + 'import', 'option', 'optional', 'required', 'repeated', + 'reserved', 'default', 'packed', 'ctype', 'extensions', 'to', + 'max', 'rpc', 'returns', 'oneof', 'syntax'), prefix=r'\b', suffix=r'\b'), + Keyword), + (words(( + 'int32', 'int64', 'uint32', 'uint64', 'sint32', 'sint64', + 'fixed32', 'fixed64', 'sfixed32', 'sfixed64', + 'float', 'double', 'bool', 'string', 'bytes'), suffix=r'\b'), + Keyword.Type), + (r'(true|false)\b', Keyword.Constant), + (r'(package)(\s+)', bygroups(Keyword.Namespace, Whitespace), 'package'), + (r'(message|extend)(\s+)', + bygroups(Keyword.Declaration, Whitespace), 'message'), + (r'(enum|group|service)(\s+)', + bygroups(Keyword.Declaration, Whitespace), 'type'), + (r'\".*?\"', String), + (r'\'.*?\'', String), + (r'(\d+\.\d*|\.\d+|\d+)[eE][+-]?\d+[LlUu]*', Number.Float), + (r'(\d+\.\d*|\.\d+|\d+[fF])[fF]?', Number.Float), + (r'(\-?(inf|nan))\b', Number.Float), + (r'0x[0-9a-fA-F]+[LlUu]*', Number.Hex), + (r'0[0-7]+[LlUu]*', Number.Oct), + (r'\d+[LlUu]*', Number.Integer), + (r'[+-=]', Operator), + (r'([a-zA-Z_][\w.]*)([ \t]*)(=)', + bygroups(Name.Attribute, Whitespace, Operator)), + (r'[a-zA-Z_][\w.]*', Name), + ], + 'package': [ + (r'[a-zA-Z_]\w*', Name.Namespace, '#pop'), + default('#pop'), + ], + 'message': [ + (r'[a-zA-Z_]\w*', Name.Class, '#pop'), + default('#pop'), + ], + 'type': [ + (r'[a-zA-Z_]\w*', Name, '#pop'), + default('#pop'), + ], + } + + +class ThriftLexer(RegexLexer): + """ + For Thrift interface definitions. + """ + name = 'Thrift' + url = 'https://thrift.apache.org/' + aliases = ['thrift'] + filenames = ['*.thrift'] + mimetypes = ['application/x-thrift'] + version_added = '2.1' + + tokens = { + 'root': [ + include('whitespace'), + include('comments'), + (r'"', String.Double, combined('stringescape', 'dqs')), + (r'\'', String.Single, combined('stringescape', 'sqs')), + (r'(namespace)(\s+)', + bygroups(Keyword.Namespace, Whitespace), 'namespace'), + (r'(enum|union|struct|service|exception)(\s+)', + bygroups(Keyword.Declaration, Whitespace), 'class'), + (r'((?:(?:[^\W\d]|\$)[\w.\[\]$<>]*\s+)+?)' # return arguments + r'((?:[^\W\d]|\$)[\w$]*)' # method name + r'(\s*)(\()', # signature start + bygroups(using(this), Name.Function, Whitespace, Operator)), + include('keywords'), + include('numbers'), + (r'[&=]', Operator), + (r'[:;,{}()<>\[\]]', Punctuation), + (r'[a-zA-Z_](\.\w|\w)*', Name), + ], + 'whitespace': [ + (r'\n', Whitespace), + (r'\s+', Whitespace), + ], + 'comments': [ + (r'#.*$', Comment), + (r'//.*?\n', Comment), + (r'/\*[\w\W]*?\*/', Comment.Multiline), + ], + 'stringescape': [ + (r'\\([\\nrt"\'])', String.Escape), + ], + 'dqs': [ + (r'"', String.Double, '#pop'), + (r'[^\\"\n]+', String.Double), + ], + 'sqs': [ + (r"'", String.Single, '#pop'), + (r'[^\\\'\n]+', String.Single), + ], + 'namespace': [ + (r'[a-z*](\.\w|\w)*', Name.Namespace, '#pop'), + default('#pop'), + ], + 'class': [ + (r'[a-zA-Z_]\w*', Name.Class, '#pop'), + default('#pop'), + ], + 'keywords': [ + (r'(async|oneway|extends|throws|required|optional)\b', Keyword), + (r'(true|false)\b', Keyword.Constant), + (r'(const|typedef)\b', Keyword.Declaration), + (words(( + 'cpp_namespace', 'cpp_include', 'cpp_type', 'java_package', + 'cocoa_prefix', 'csharp_namespace', 'delphi_namespace', + 'php_namespace', 'py_module', 'perl_package', + 'ruby_namespace', 'smalltalk_category', 'smalltalk_prefix', + 'xsd_all', 'xsd_optional', 'xsd_nillable', 'xsd_namespace', + 'xsd_attrs', 'include'), suffix=r'\b'), + Keyword.Namespace), + (words(( + 'void', 'bool', 'byte', 'i16', 'i32', 'i64', 'double', + 'string', 'binary', 'map', 'list', 'set', 'slist', + 'senum'), suffix=r'\b'), + Keyword.Type), + (words(( + 'BEGIN', 'END', '__CLASS__', '__DIR__', '__FILE__', + '__FUNCTION__', '__LINE__', '__METHOD__', '__NAMESPACE__', + 'abstract', 'alias', 'and', 'args', 'as', 'assert', 'begin', + 'break', 'case', 'catch', 'class', 'clone', 'continue', + 'declare', 'def', 'default', 'del', 'delete', 'do', 'dynamic', + 'elif', 'else', 'elseif', 'elsif', 'end', 'enddeclare', + 'endfor', 'endforeach', 'endif', 'endswitch', 'endwhile', + 'ensure', 'except', 'exec', 'finally', 'float', 'for', + 'foreach', 'function', 'global', 'goto', 'if', 'implements', + 'import', 'in', 'inline', 'instanceof', 'interface', 'is', + 'lambda', 'module', 'native', 'new', 'next', 'nil', 'not', + 'or', 'pass', 'public', 'print', 'private', 'protected', + 'raise', 'redo', 'rescue', 'retry', 'register', 'return', + 'self', 'sizeof', 'static', 'super', 'switch', 'synchronized', + 'then', 'this', 'throw', 'transient', 'try', 'undef', + 'unless', 'unsigned', 'until', 'use', 'var', 'virtual', + 'volatile', 'when', 'while', 'with', 'xor', 'yield'), + prefix=r'\b', suffix=r'\b'), + Keyword.Reserved), + ], + 'numbers': [ + (r'[+-]?(\d+\.\d+([eE][+-]?\d+)?|\.?\d+[eE][+-]?\d+)', Number.Float), + (r'[+-]?0x[0-9A-Fa-f]+', Number.Hex), + (r'[+-]?[0-9]+', Number.Integer), + ], + } + + +class ZeekLexer(RegexLexer): + """ + For Zeek scripts. + """ + name = 'Zeek' + url = 'https://www.zeek.org/' + aliases = ['zeek', 'bro'] + filenames = ['*.zeek', '*.bro'] + version_added = '2.5' + + _hex = r'[0-9a-fA-F]' + _float = r'((\d*\.?\d+)|(\d+\.?\d*))([eE][-+]?\d+)?' + _h = r'[A-Za-z0-9][-A-Za-z0-9]*' + + tokens = { + 'root': [ + include('whitespace'), + include('comments'), + include('directives'), + include('attributes'), + include('types'), + include('keywords'), + include('literals'), + include('operators'), + include('punctuation'), + (r'((?:[A-Za-z_]\w*)(?:::(?:[A-Za-z_]\w*))*)(?=\s*\()', + Name.Function), + include('identifiers'), + ], + + 'whitespace': [ + (r'\n', Whitespace), + (r'\s+', Whitespace), + (r'(\\)(\n)', bygroups(Text, Whitespace)), + ], + + 'comments': [ + (r'#.*$', Comment), + ], + + 'directives': [ + (r'@(load-plugin|load-sigs|load|unload)\b.*$', Comment.Preproc), + (r'@(DEBUG|DIR|FILENAME|deprecated|if|ifdef|ifndef|else|endif)\b', Comment.Preproc), + (r'(@prefixes)(\s*)((\+?=).*)$', bygroups(Comment.Preproc, + Whitespace, Comment.Preproc)), + ], + + 'attributes': [ + (words(('redef', 'priority', 'log', 'optional', 'default', 'add_func', + 'delete_func', 'expire_func', 'read_expire', 'write_expire', + 'create_expire', 'synchronized', 'persistent', 'rotate_interval', + 'rotate_size', 'encrypt', 'raw_output', 'mergeable', 'error_handler', + 'type_column', 'deprecated'), + prefix=r'&', suffix=r'\b'), + Keyword.Pseudo), + ], + + 'types': [ + (words(('any', + 'enum', 'record', 'set', 'table', 'vector', + 'function', 'hook', 'event', + 'addr', 'bool', 'count', 'double', 'file', 'int', 'interval', + 'pattern', 'port', 'string', 'subnet', 'time'), + suffix=r'\b'), + Keyword.Type), + + (r'(opaque)(\s+)(of)(\s+)((?:[A-Za-z_]\w*)(?:::(?:[A-Za-z_]\w*))*)\b', + bygroups(Keyword.Type, Whitespace, Operator.Word, Whitespace, Keyword.Type)), + + (r'(type)(\s+)((?:[A-Za-z_]\w*)(?:::(?:[A-Za-z_]\w*))*)(\s*)(:)(\s*)\b(record|enum)\b', + bygroups(Keyword, Whitespace, Name.Class, Whitespace, Operator, Whitespace, Keyword.Type)), + + (r'(type)(\s+)((?:[A-Za-z_]\w*)(?:::(?:[A-Za-z_]\w*))*)(\s*)(:)', + bygroups(Keyword, Whitespace, Name, Whitespace, Operator)), + + (r'(redef)(\s+)(record|enum)(\s+)((?:[A-Za-z_]\w*)(?:::(?:[A-Za-z_]\w*))*)\b', + bygroups(Keyword, Whitespace, Keyword.Type, Whitespace, Name.Class)), + ], + + 'keywords': [ + (words(('redef', 'export', 'if', 'else', 'for', 'while', + 'return', 'break', 'next', 'continue', 'fallthrough', + 'switch', 'default', 'case', + 'add', 'delete', + 'when', 'timeout', 'schedule'), + suffix=r'\b'), + Keyword), + (r'(print)\b', Keyword), + (r'(global|local|const|option)\b', Keyword.Declaration), + (r'(module)(\s+)(([A-Za-z_]\w*)(?:::([A-Za-z_]\w*))*)\b', + bygroups(Keyword.Namespace, Whitespace, Name.Namespace)), + ], + + 'literals': [ + (r'"', String, 'string'), + + # Not the greatest match for patterns, but generally helps + # disambiguate between start of a pattern and just a division + # operator. + (r'/(?=.*/)', String.Regex, 'regex'), + + (r'(T|F)\b', Keyword.Constant), + + # Port + (r'\d{1,5}/(udp|tcp|icmp|unknown)\b', Number), + + # IPv4 Address + (r'(\d{1,3}.){3}(\d{1,3})\b', Number), + + # IPv6 Address + (r'\[([0-9a-fA-F]{0,4}:){2,7}([0-9a-fA-F]{0,4})?((\d{1,3}.){3}(\d{1,3}))?\]', Number), + + # Numeric + (r'0[xX]' + _hex + r'+\b', Number.Hex), + (_float + r'\s*(day|hr|min|sec|msec|usec)s?\b', Number.Float), + (_float + r'\b', Number.Float), + (r'(\d+)\b', Number.Integer), + + # Hostnames + (_h + r'(\.' + _h + r')+', String), + ], + + 'operators': [ + (r'[!%*/+<=>~|&^-]', Operator), + (r'([-+=&|]{2}|[+=!><-]=)', Operator), + (r'(in|as|is|of)\b', Operator.Word), + (r'\??\$', Operator), + ], + + 'punctuation': [ + (r'[{}()\[\],;.]', Punctuation), + # The "ternary if", which uses '?' and ':', could instead be + # treated as an Operator, but colons are more frequently used to + # separate field/identifier names from their types, so the (often) + # less-prominent Punctuation is used even with '?' for consistency. + (r'[?:]', Punctuation), + ], + + 'identifiers': [ + (r'([a-zA-Z_]\w*)(::)', bygroups(Name, Punctuation)), + (r'[a-zA-Z_]\w*', Name) + ], + + 'string': [ + (r'\\.', String.Escape), + (r'%-?[0-9]*(\.[0-9]+)?[DTd-gsx]', String.Escape), + (r'"', String, '#pop'), + (r'.', String), + ], + + 'regex': [ + (r'\\.', String.Escape), + (r'/', String.Regex, '#pop'), + (r'.', String.Regex), + ], + } + + +BroLexer = ZeekLexer + + +class PuppetLexer(RegexLexer): + """ + For Puppet configuration DSL. + """ + name = 'Puppet' + url = 'https://puppet.com/' + aliases = ['puppet'] + filenames = ['*.pp'] + version_added = '1.6' + + tokens = { + 'root': [ + include('comments'), + include('keywords'), + include('names'), + include('numbers'), + include('operators'), + include('strings'), + + (r'[]{}:(),;[]', Punctuation), + (r'\s+', Whitespace), + ], + + 'comments': [ + (r'(\s*)(#.*)$', bygroups(Whitespace, Comment)), + (r'/(\\\n)?[*](.|\n)*?[*](\\\n)?/', Comment.Multiline), + ], + + 'operators': [ + (r'(=>|\?|<|>|=|\+|-|/|\*|~|!|\|)', Operator), + (r'(in|and|or|not)\b', Operator.Word), + ], + + 'names': [ + (r'[a-zA-Z_]\w*', Name.Attribute), + (r'(\$\S+)(\[)(\S+)(\])', bygroups(Name.Variable, Punctuation, + String, Punctuation)), + (r'\$\S+', Name.Variable), + ], + + 'numbers': [ + # Copypasta from the Python lexer + (r'(\d+\.\d*|\d*\.\d+)([eE][+-]?[0-9]+)?j?', Number.Float), + (r'\d+[eE][+-]?[0-9]+j?', Number.Float), + (r'0[0-7]+j?', Number.Oct), + (r'0[xX][a-fA-F0-9]+', Number.Hex), + (r'\d+L', Number.Integer.Long), + (r'\d+j?', Number.Integer) + ], + + 'keywords': [ + # Left out 'group' and 'require' + # Since they're often used as attributes + (words(( + 'absent', 'alert', 'alias', 'audit', 'augeas', 'before', 'case', + 'check', 'class', 'computer', 'configured', 'contained', + 'create_resources', 'crit', 'cron', 'debug', 'default', + 'define', 'defined', 'directory', 'else', 'elsif', 'emerg', + 'err', 'exec', 'extlookup', 'fail', 'false', 'file', + 'filebucket', 'fqdn_rand', 'generate', 'host', 'if', 'import', + 'include', 'info', 'inherits', 'inline_template', 'installed', + 'interface', 'k5login', 'latest', 'link', 'loglevel', + 'macauthorization', 'mailalias', 'maillist', 'mcx', 'md5', + 'mount', 'mounted', 'nagios_command', 'nagios_contact', + 'nagios_contactgroup', 'nagios_host', 'nagios_hostdependency', + 'nagios_hostescalation', 'nagios_hostextinfo', 'nagios_hostgroup', + 'nagios_service', 'nagios_servicedependency', 'nagios_serviceescalation', + 'nagios_serviceextinfo', 'nagios_servicegroup', 'nagios_timeperiod', + 'node', 'noop', 'notice', 'notify', 'package', 'present', 'purged', + 'realize', 'regsubst', 'resources', 'role', 'router', 'running', + 'schedule', 'scheduled_task', 'search', 'selboolean', 'selmodule', + 'service', 'sha1', 'shellquote', 'split', 'sprintf', + 'ssh_authorized_key', 'sshkey', 'stage', 'stopped', 'subscribe', + 'tag', 'tagged', 'template', 'tidy', 'true', 'undef', 'unmounted', + 'user', 'versioncmp', 'vlan', 'warning', 'yumrepo', 'zfs', 'zone', + 'zpool'), prefix='(?i)', suffix=r'\b'), + Keyword), + ], + + 'strings': [ + (r'"([^"])*"', String), + (r"'(\\'|[^'])*'", String), + ], + + } + + +class RslLexer(RegexLexer): + """ + RSL is the formal specification + language used in RAISE (Rigorous Approach to Industrial Software Engineering) + method. + """ + name = 'RSL' + url = 'http://en.wikipedia.org/wiki/RAISE' + aliases = ['rsl'] + filenames = ['*.rsl'] + mimetypes = ['text/rsl'] + version_added = '2.0' + + flags = re.MULTILINE | re.DOTALL + + tokens = { + 'root': [ + (words(( + 'Bool', 'Char', 'Int', 'Nat', 'Real', 'Text', 'Unit', 'abs', + 'all', 'always', 'any', 'as', 'axiom', 'card', 'case', 'channel', + 'chaos', 'class', 'devt_relation', 'dom', 'elems', 'else', 'elif', + 'end', 'exists', 'extend', 'false', 'for', 'hd', 'hide', 'if', + 'in', 'is', 'inds', 'initialise', 'int', 'inter', 'isin', 'len', + 'let', 'local', 'ltl_assertion', 'object', 'of', 'out', 'post', + 'pre', 'read', 'real', 'rng', 'scheme', 'skip', 'stop', 'swap', + 'then', 'theory', 'test_case', 'tl', 'transition_system', 'true', + 'type', 'union', 'until', 'use', 'value', 'variable', 'while', + 'with', 'write', '~isin', '-inflist', '-infset', '-list', + '-set'), prefix=r'\b', suffix=r'\b'), + Keyword), + (r'(variable|value)\b', Keyword.Declaration), + (r'--.*?\n', Comment), + (r'<:.*?:>', Comment), + (r'\{!.*?!\}', Comment), + (r'/\*.*?\*/', Comment), + (r'^([ \t]*)([\w]+)([ \t]*)(:[^:])', bygroups(Whitespace, + Name.Function, Whitespace, Name.Function)), + (r'(^[ \t]*)([\w]+)([ \t]*)(\([\w\s,]*\))([ \t]*)(is|as)', + bygroups(Whitespace, Name.Function, Whitespace, Text, + Whitespace, Keyword)), + (r'\b[A-Z]\w*\b', Keyword.Type), + (r'(true|false)\b', Keyword.Constant), + (r'".*"', String), + (r'\'.\'', String.Char), + (r'(><|->|-m->|/\\|<=|<<=|<\.|\|\||\|\^\||-~->|-~m->|\\/|>=|>>|' + r'\.>|\+\+|-\\|<->|=>|:-|~=|\*\*|<<|>>=|\+>|!!|\|=\||#)', + Operator), + (r'[0-9]+\.[0-9]+([eE][0-9]+)?[fd]?', Number.Float), + (r'0x[0-9a-f]+', Number.Hex), + (r'[0-9]+', Number.Integer), + (r'\s+', Whitespace), + (r'.', Text), + ], + } + + def analyse_text(text): + """ + Check for the most common text in the beginning of a RSL file. + """ + if re.search(r'scheme\s*.*?=\s*class\s*type', text, re.I) is not None: + return 1.0 + + +class MscgenLexer(RegexLexer): + """ + For Mscgen files. + """ + name = 'Mscgen' + url = 'http://www.mcternan.me.uk/mscgen/' + aliases = ['mscgen', 'msc'] + filenames = ['*.msc'] + version_added = '1.6' + + _var = r'(\w+|"(?:\\"|[^"])*")' + + tokens = { + 'root': [ + (r'msc\b', Keyword.Type), + # Options + (r'(hscale|HSCALE|width|WIDTH|wordwraparcs|WORDWRAPARCS' + r'|arcgradient|ARCGRADIENT)\b', Name.Property), + # Operators + (r'(abox|ABOX|rbox|RBOX|box|BOX|note|NOTE)\b', Operator.Word), + (r'(\.|-|\|){3}', Keyword), + (r'(?:-|=|\.|:){2}' + r'|<<=>>|<->|<=>|<<>>|<:>' + r'|->|=>>|>>|=>|:>|-x|-X' + r'|<-|<<=|<<|<=|<:|x-|X-|=', Operator), + # Names + (r'\*', Name.Builtin), + (_var, Name.Variable), + # Other + (r'\[', Punctuation, 'attrs'), + (r'\{|\}|,|;', Punctuation), + include('comments') + ], + 'attrs': [ + (r'\]', Punctuation, '#pop'), + (_var + r'(\s*)(=)(\s*)' + _var, + bygroups(Name.Attribute, Whitespace, Operator, Whitespace, + String)), + (r',', Punctuation), + include('comments') + ], + 'comments': [ + (r'(?://|#).*?\n', Comment.Single), + (r'/\*(?:.|\n)*?\*/', Comment.Multiline), + (r'[ \t\r\n]+', Whitespace) + ] + } + + +class VGLLexer(RegexLexer): + """ + For SampleManager VGL source code. + """ + name = 'VGL' + url = 'http://www.thermoscientific.com/samplemanager' + aliases = ['vgl'] + filenames = ['*.rpf'] + version_added = '1.6' + + flags = re.MULTILINE | re.DOTALL | re.IGNORECASE + + tokens = { + 'root': [ + (r'\{[^}]*\}', Comment.Multiline), + (r'declare', Keyword.Constant), + (r'(if|then|else|endif|while|do|endwhile|and|or|prompt|object' + r'|create|on|line|with|global|routine|value|endroutine|constant' + r'|global|set|join|library|compile_option|file|exists|create|copy' + r'|delete|enable|windows|name|notprotected)(?! *[=<>.,()])', + Keyword), + (r'(true|false|null|empty|error|locked)', Keyword.Constant), + (r'[~^*#!%&\[\]()<>|+=:;,./?-]', Operator), + (r'"[^"]*"', String), + (r'(\.)([a-z_$][\w$]*)', bygroups(Operator, Name.Attribute)), + (r'[0-9][0-9]*(\.[0-9]+(e[+\-]?[0-9]+)?)?', Number), + (r'[a-z_$][\w$]*', Name), + (r'[\r\n]+', Whitespace), + (r'\s+', Whitespace) + ] + } + + +class AlloyLexer(RegexLexer): + """ + For Alloy source code. + """ + + name = 'Alloy' + url = 'http://alloy.mit.edu' + aliases = ['alloy'] + filenames = ['*.als'] + mimetypes = ['text/x-alloy'] + version_added = '2.0' + + flags = re.MULTILINE | re.DOTALL + + iden_rex = r'[a-zA-Z_][\w]*"*' + string_rex = r'"\b(\\\\|\\[^\\]|[^"\\])*"' + text_tuple = (r'[^\S\n]+', Whitespace) + + tokens = { + 'sig': [ + (r'(extends)\b', Keyword, '#pop'), + (iden_rex, Name), + text_tuple, + (r',', Punctuation), + (r'\{', Operator, '#pop'), + ], + 'module': [ + text_tuple, + (iden_rex, Name, '#pop'), + ], + 'fun': [ + text_tuple, + (r'\{', Operator, '#pop'), + (iden_rex, Name, '#pop'), + ], + 'fact': [ + include('fun'), + (string_rex, String, '#pop'), + ], + 'root': [ + (r'--.*?$', Comment.Single), + (r'//.*?$', Comment.Single), + (r'/\*.*?\*/', Comment.Multiline), + text_tuple, + (r'(module|open)(\s+)', bygroups(Keyword.Namespace, Whitespace), + 'module'), + (r'(sig|enum)(\s+)', bygroups(Keyword.Declaration, Whitespace), 'sig'), + (r'(iden|univ|none)\b', Keyword.Constant), + (r'(int|Int)\b', Keyword.Type), + (r'(var|this|abstract|extends|set|seq|one|lone|let)\b', Keyword), + (r'(all|some|no|sum|disj|when|else)\b', Keyword), + (r'(run|check|for|but|exactly|expect|as|steps)\b', Keyword), + (r'(always|after|eventually|until|release)\b', Keyword), # future time operators + (r'(historically|before|once|since|triggered)\b', Keyword), # past time operators + (r'(and|or|implies|iff|in)\b', Operator.Word), + (r'(fun|pred|assert)(\s+)', bygroups(Keyword, Whitespace), 'fun'), + (r'(fact)(\s+)', bygroups(Keyword, Whitespace), 'fact'), + (r'!|#|&&|\+\+|<<|>>|>=|<=>|<=|\.\.|\.|->', Operator), + (r'[-+/*%=<>&!^|~{}\[\]().\';]', Operator), + (iden_rex, Name), + (r'[:,]', Punctuation), + (r'[0-9]+', Number.Integer), + (string_rex, String), + (r'\n', Whitespace), + ] + } + + +class PanLexer(RegexLexer): + """ + Lexer for pan source files. + + Based on tcsh lexer. + """ + + name = 'Pan' + url = 'https://github.com/quattor/pan/' + aliases = ['pan'] + filenames = ['*.pan'] + version_added = '2.0' + + tokens = { + 'root': [ + include('basic'), + (r'\(', Keyword, 'paren'), + (r'\{', Keyword, 'curly'), + include('data'), + ], + 'basic': [ + (words(( + 'if', 'for', 'with', 'else', 'type', 'bind', 'while', 'valid', 'final', + 'prefix', 'unique', 'object', 'foreach', 'include', 'template', + 'function', 'variable', 'structure', 'extensible', 'declaration'), + prefix=r'\b', suffix=r'\b'), + Keyword), + (words(( + 'file_contents', 'format', 'index', 'length', 'match', 'matches', + 'replace', 'splice', 'split', 'substr', 'to_lowercase', 'to_uppercase', + 'debug', 'error', 'traceback', 'deprecated', 'base64_decode', + 'base64_encode', 'digest', 'escape', 'unescape', 'append', 'create', + 'first', 'nlist', 'key', 'list', 'merge', 'next', 'prepend', 'is_boolean', + 'is_defined', 'is_double', 'is_list', 'is_long', 'is_nlist', 'is_null', + 'is_number', 'is_property', 'is_resource', 'is_string', 'to_boolean', + 'to_double', 'to_long', 'to_string', 'clone', 'delete', 'exists', + 'path_exists', 'if_exists', 'return', 'value'), + prefix=r'\b', suffix=r'\b'), + Name.Builtin), + (r'#.*', Comment), + (r'\\[\w\W]', String.Escape), + (r'(\b\w+)(\s*)(=)', bygroups(Name.Variable, Whitespace, Operator)), + (r'[\[\]{}()=]+', Operator), + (r'<<\s*(\'?)\\?(\w+)[\w\W]+?\2', String), + (r';', Punctuation), + ], + 'data': [ + (r'(?s)"(\\\\|\\[0-7]+|\\.|[^"\\])*"', String.Double), + (r"(?s)'(\\\\|\\[0-7]+|\\.|[^'\\])*'", String.Single), + (r'\s+', Whitespace), + (r'[^=\s\[\]{}()$"\'`\\;#]+', Text), + (r'\d+(?= |\Z)', Number), + ], + 'curly': [ + (r'\}', Keyword, '#pop'), + (r':-', Keyword), + (r'\w+', Name.Variable), + (r'[^}:"\'`$]+', Punctuation), + (r':', Punctuation), + include('root'), + ], + 'paren': [ + (r'\)', Keyword, '#pop'), + include('root'), + ], + } + + +class CrmshLexer(RegexLexer): + """ + Lexer for crmsh configuration files for Pacemaker clusters. + """ + name = 'Crmsh' + url = 'http://crmsh.github.io/' + aliases = ['crmsh', 'pcmk'] + filenames = ['*.crmsh', '*.pcmk'] + mimetypes = [] + version_added = '2.1' + + elem = words(( + 'node', 'primitive', 'group', 'clone', 'ms', 'location', + 'colocation', 'order', 'fencing_topology', 'rsc_ticket', + 'rsc_template', 'property', 'rsc_defaults', + 'op_defaults', 'acl_target', 'acl_group', 'user', 'role', + 'tag'), suffix=r'(?![\w#$-])') + sub = words(( + 'params', 'meta', 'operations', 'op', 'rule', + 'attributes', 'utilization'), suffix=r'(?![\w#$-])') + acl = words(('read', 'write', 'deny'), suffix=r'(?![\w#$-])') + bin_rel = words(('and', 'or'), suffix=r'(?![\w#$-])') + un_ops = words(('defined', 'not_defined'), suffix=r'(?![\w#$-])') + date_exp = words(('in_range', 'date', 'spec', 'in'), suffix=r'(?![\w#$-])') + acl_mod = (r'(?:tag|ref|reference|attribute|type|xpath)') + bin_ops = (r'(?:lt|gt|lte|gte|eq|ne)') + val_qual = (r'(?:string|version|number)') + rsc_role_action = (r'(?:Master|Started|Slave|Stopped|' + r'start|promote|demote|stop)') + + tokens = { + 'root': [ + (r'^(#.*)(\n)?', bygroups(Comment, Whitespace)), + # attr=value (nvpair) + (r'([\w#$-]+)(=)("(?:""|[^"])*"|\S+)', + bygroups(Name.Attribute, Punctuation, String)), + # need this construct, otherwise numeric node ids + # are matched as scores + # elem id: + (r'(node)(\s+)([\w#$-]+)(:)', + bygroups(Keyword, Whitespace, Name, Punctuation)), + # scores + (r'([+-]?([0-9]+|inf)):', Number), + # keywords (elements and other) + (elem, Keyword), + (sub, Keyword), + (acl, Keyword), + # binary operators + (rf'(?:{val_qual}:)?({bin_ops})(?![\w#$-])', Operator.Word), + # other operators + (bin_rel, Operator.Word), + (un_ops, Operator.Word), + (date_exp, Operator.Word), + # builtin attributes (e.g. #uname) + (r'#[a-z]+(?![\w#$-])', Name.Builtin), + # acl_mod:blah + (rf'({acl_mod})(:)("(?:""|[^"])*"|\S+)', + bygroups(Keyword, Punctuation, Name)), + # rsc_id[:(role|action)] + # NB: this matches all other identifiers + (rf'([\w#$-]+)(?:(:)({rsc_role_action}))?(?![\w#$-])', + bygroups(Name, Punctuation, Operator.Word)), + # punctuation + (r'(\\(?=\n)|[\[\](){}/:@])', Punctuation), + (r'\s+|\n', Whitespace), + ], + } + + +class FlatlineLexer(RegexLexer): + """ + Lexer for Flatline expressions. + """ + name = 'Flatline' + url = 'https://github.com/bigmlcom/flatline' + aliases = ['flatline'] + filenames = [] + mimetypes = ['text/x-flatline'] + version_added = '2.2' + + special_forms = ('let',) + + builtins = ( + "!=", "*", "+", "-", "<", "<=", "=", ">", ">=", "abs", "acos", "all", + "all-but", "all-with-defaults", "all-with-numeric-default", "and", + "asin", "atan", "avg", "avg-window", "bin-center", "bin-count", "call", + "category-count", "ceil", "cond", "cond-window", "cons", "cos", "cosh", + "count", "diff-window", "div", "ensure-value", "ensure-weighted-value", + "epoch", "epoch-day", "epoch-fields", "epoch-hour", "epoch-millisecond", + "epoch-minute", "epoch-month", "epoch-second", "epoch-weekday", + "epoch-year", "exp", "f", "field", "field-prop", "fields", "filter", + "first", "floor", "head", "if", "in", "integer", "language", "length", + "levenshtein", "linear-regression", "list", "ln", "log", "log10", "map", + "matches", "matches?", "max", "maximum", "md5", "mean", "median", "min", + "minimum", "missing", "missing-count", "missing?", "missing_count", + "mod", "mode", "normalize", "not", "nth", "occurrences", "or", + "percentile", "percentile-label", "population", "population-fraction", + "pow", "preferred", "preferred?", "quantile-label", "rand", "rand-int", + "random-value", "re-quote", "real", "replace", "replace-first", "rest", + "round", "row-number", "segment-label", "sha1", "sha256", "sin", "sinh", + "sqrt", "square", "standard-deviation", "standard_deviation", "str", + "subs", "sum", "sum-squares", "sum-window", "sum_squares", "summary", + "summary-no", "summary-str", "tail", "tan", "tanh", "to-degrees", + "to-radians", "variance", "vectorize", "weighted-random-value", "window", + "winnow", "within-percentiles?", "z-score", + ) + + valid_name = r'(?!#)[\w!$%*+<=>?/.#-]+' + + tokens = { + 'root': [ + # whitespaces - usually not relevant + (r'[,]+', Text), + (r'\s+', Whitespace), + + # numbers + (r'-?\d+\.\d+', Number.Float), + (r'-?\d+', Number.Integer), + (r'0x-?[a-f\d]+', Number.Hex), + + # strings, symbols and characters + (r'"(\\\\|\\[^\\]|[^"\\])*"', String), + (r"\\(.|[a-z]+)", String.Char), + + # expression template placeholder + (r'_', String.Symbol), + + # highlight the special forms + (words(special_forms, suffix=' '), Keyword), + + # highlight the builtins + (words(builtins, suffix=' '), Name.Builtin), + + # the remaining functions + (r'(?<=\()' + valid_name, Name.Function), + + # find the remaining variables + (valid_name, Name.Variable), + + # parentheses + (r'(\(|\))', Punctuation), + ], + } + + +class SnowballLexer(ExtendedRegexLexer): + """ + Lexer for Snowball source code. + """ + + name = 'Snowball' + url = 'https://snowballstem.org/' + aliases = ['snowball'] + filenames = ['*.sbl'] + version_added = '2.2' + + _ws = r'\n\r\t ' + + def __init__(self, **options): + self._reset_stringescapes() + ExtendedRegexLexer.__init__(self, **options) + + def _reset_stringescapes(self): + self._start = "'" + self._end = "'" + + def _string(do_string_first): + def callback(lexer, match, ctx): + s = match.start() + text = match.group() + string = re.compile(rf'([^{re.escape(lexer._start)}]*)(.)').match + escape = re.compile(rf'([^{re.escape(lexer._end)}]*)(.)').match + pos = 0 + do_string = do_string_first + while pos < len(text): + if do_string: + match = string(text, pos) + yield s + match.start(1), String.Single, match.group(1) + if match.group(2) == "'": + yield s + match.start(2), String.Single, match.group(2) + ctx.stack.pop() + break + yield s + match.start(2), String.Escape, match.group(2) + pos = match.end() + match = escape(text, pos) + yield s + match.start(), String.Escape, match.group() + if match.group(2) != lexer._end: + ctx.stack[-1] = 'escape' + break + pos = match.end() + do_string = True + ctx.pos = s + match.end() + return callback + + def _stringescapes(lexer, match, ctx): + lexer._start = match.group(3) + lexer._end = match.group(5) + return bygroups(Keyword.Reserved, Whitespace, String.Escape, Whitespace, + String.Escape)(lexer, match, ctx) + + tokens = { + 'root': [ + (r'len\b', Name.Builtin), + (r'lenof\b', Operator.Word), + include('root1'), + ], + 'root1': [ + (rf'[{_ws}]+', Whitespace), + (r'\d+', Number.Integer), + (r"'", String.Single, 'string'), + (r'[()]', Punctuation), + (r'/\*[\w\W]*?\*/', Comment.Multiline), + (r'//.*', Comment.Single), + (r'[!*+\-/<=>]=|[-=]>|<[+-]|[$*+\-/<=>?\[\]]', Operator), + (words(('as', 'get', 'hex', 'among', 'define', 'decimal', + 'backwardmode'), suffix=r'\b'), + Keyword.Reserved), + (words(('strings', 'booleans', 'integers', 'routines', 'externals', + 'groupings'), suffix=r'\b'), + Keyword.Reserved, 'declaration'), + (words(('do', 'or', 'and', 'for', 'hop', 'non', 'not', 'set', 'try', + 'fail', 'goto', 'loop', 'next', 'test', 'true', + 'false', 'unset', 'atmark', 'attach', 'delete', 'gopast', + 'insert', 'repeat', 'sizeof', 'tomark', 'atleast', + 'atlimit', 'reverse', 'setmark', 'tolimit', 'setlimit', + 'backwards', 'substring'), suffix=r'\b'), + Operator.Word), + (words(('size', 'limit', 'cursor', 'maxint', 'minint'), + suffix=r'\b'), + Name.Builtin), + (rf'(stringdef\b)([{_ws}]*)([^{_ws}]+)', + bygroups(Keyword.Reserved, Whitespace, String.Escape)), + (rf'(stringescapes\b)([{_ws}]*)(.)([{_ws}]*)(.)', + _stringescapes), + (r'[A-Za-z]\w*', Name), + ], + 'declaration': [ + (r'\)', Punctuation, '#pop'), + (words(('len', 'lenof'), suffix=r'\b'), Name, + ('root1', 'declaration')), + include('root1'), + ], + 'string': [ + (r"[^']*'", _string(True)), + ], + 'escape': [ + (r"[^']*'", _string(False)), + ], + } + + def get_tokens_unprocessed(self, text=None, context=None): + self._reset_stringescapes() + return ExtendedRegexLexer.get_tokens_unprocessed(self, text, context) diff --git a/pygments/lexers/dylan.py b/pygments/lexers/dylan.py new file mode 100644 index 0000000000000000000000000000000000000000..2109bd52c8f775cdd9e2ab4fee4b995df834af6e --- /dev/null +++ b/pygments/lexers/dylan.py @@ -0,0 +1,279 @@ +""" + pygments.lexers.dylan + ~~~~~~~~~~~~~~~~~~~~~ + + Lexers for the Dylan language. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +import re + +from pygments.lexer import Lexer, RegexLexer, bygroups, do_insertions, \ + default, line_re +from pygments.token import Comment, Operator, Keyword, Name, String, \ + Number, Punctuation, Generic, Literal, Whitespace + +__all__ = ['DylanLexer', 'DylanConsoleLexer', 'DylanLidLexer'] + + +class DylanLexer(RegexLexer): + """ + For the Dylan language. + """ + + name = 'Dylan' + url = 'http://www.opendylan.org/' + aliases = ['dylan'] + filenames = ['*.dylan', '*.dyl', '*.intr'] + mimetypes = ['text/x-dylan'] + version_added = '0.7' + + flags = re.IGNORECASE + + builtins = { + 'subclass', 'abstract', 'block', 'concrete', 'constant', 'class', + 'compiler-open', 'compiler-sideways', 'domain', 'dynamic', + 'each-subclass', 'exception', 'exclude', 'function', 'generic', + 'handler', 'inherited', 'inline', 'inline-only', 'instance', + 'interface', 'import', 'keyword', 'library', 'macro', 'method', + 'module', 'open', 'primary', 'required', 'sealed', 'sideways', + 'singleton', 'slot', 'thread', 'variable', 'virtual'} + + keywords = { + 'above', 'afterwards', 'begin', 'below', 'by', 'case', 'cleanup', + 'create', 'define', 'else', 'elseif', 'end', 'export', 'finally', + 'for', 'from', 'if', 'in', 'let', 'local', 'otherwise', 'rename', + 'select', 'signal', 'then', 'to', 'unless', 'until', 'use', 'when', + 'while'} + + operators = { + '~', '+', '-', '*', '|', '^', '=', '==', '~=', '~==', '<', '<=', + '>', '>=', '&', '|'} + + functions = { + 'abort', 'abs', 'add', 'add!', 'add-method', 'add-new', 'add-new!', + 'all-superclasses', 'always', 'any?', 'applicable-method?', 'apply', + 'aref', 'aref-setter', 'as', 'as-lowercase', 'as-lowercase!', + 'as-uppercase', 'as-uppercase!', 'ash', 'backward-iteration-protocol', + 'break', 'ceiling', 'ceiling/', 'cerror', 'check-type', 'choose', + 'choose-by', 'complement', 'compose', 'concatenate', 'concatenate-as', + 'condition-format-arguments', 'condition-format-string', 'conjoin', + 'copy-sequence', 'curry', 'default-handler', 'dimension', 'dimensions', + 'direct-subclasses', 'direct-superclasses', 'disjoin', 'do', + 'do-handlers', 'element', 'element-setter', 'empty?', 'error', 'even?', + 'every?', 'false-or', 'fill!', 'find-key', 'find-method', 'first', + 'first-setter', 'floor', 'floor/', 'forward-iteration-protocol', + 'function-arguments', 'function-return-values', + 'function-specializers', 'gcd', 'generic-function-mandatory-keywords', + 'generic-function-methods', 'head', 'head-setter', 'identity', + 'initialize', 'instance?', 'integral?', 'intersection', + 'key-sequence', 'key-test', 'last', 'last-setter', 'lcm', 'limited', + 'list', 'logand', 'logbit?', 'logior', 'lognot', 'logxor', 'make', + 'map', 'map-as', 'map-into', 'max', 'member?', 'merge-hash-codes', + 'min', 'modulo', 'negative', 'negative?', 'next-method', + 'object-class', 'object-hash', 'odd?', 'one-of', 'pair', 'pop', + 'pop-last', 'positive?', 'push', 'push-last', 'range', 'rank', + 'rcurry', 'reduce', 'reduce1', 'remainder', 'remove', 'remove!', + 'remove-duplicates', 'remove-duplicates!', 'remove-key!', + 'remove-method', 'replace-elements!', 'replace-subsequence!', + 'restart-query', 'return-allowed?', 'return-description', + 'return-query', 'reverse', 'reverse!', 'round', 'round/', + 'row-major-index', 'second', 'second-setter', 'shallow-copy', + 'signal', 'singleton', 'size', 'size-setter', 'slot-initialized?', + 'sort', 'sort!', 'sorted-applicable-methods', 'subsequence-position', + 'subtype?', 'table-protocol', 'tail', 'tail-setter', 'third', + 'third-setter', 'truncate', 'truncate/', 'type-error-expected-type', + 'type-error-value', 'type-for-copy', 'type-union', 'union', 'values', + 'vector', 'zero?'} + + valid_name = '\\\\?[\\w!&*<>|^$%@\\-+~?/=]+' + + def get_tokens_unprocessed(self, text): + for index, token, value in RegexLexer.get_tokens_unprocessed(self, text): + if token is Name: + lowercase_value = value.lower() + if lowercase_value in self.builtins: + yield index, Name.Builtin, value + continue + if lowercase_value in self.keywords: + yield index, Keyword, value + continue + if lowercase_value in self.functions: + yield index, Name.Builtin, value + continue + if lowercase_value in self.operators: + yield index, Operator, value + continue + yield index, token, value + + tokens = { + 'root': [ + # Whitespace + (r'\s+', Whitespace), + + # single line comment + (r'//.*?\n', Comment.Single), + + # lid header + (r'([a-z0-9-]+)(:)([ \t]*)(.*(?:\n[ \t].+)*)', + bygroups(Name.Attribute, Operator, Whitespace, String)), + + default('code') # no header match, switch to code + ], + 'code': [ + # Whitespace + (r'\s+', Whitespace), + + # single line comment + (r'(//.*?)(\n)', bygroups(Comment.Single, Whitespace)), + + # multi-line comment + (r'/\*', Comment.Multiline, 'comment'), + + # strings and characters + (r'"', String, 'string'), + (r"'(\\.|\\[0-7]{1,3}|\\x[a-f0-9]{1,2}|[^\\\'\n])'", String.Char), + + # binary integer + (r'#b[01]+', Number.Bin), + + # octal integer + (r'#o[0-7]+', Number.Oct), + + # floating point + (r'[-+]?(\d*\.\d+(e[-+]?\d+)?|\d+(\.\d*)?e[-+]?\d+)', Number.Float), + + # decimal integer + (r'[-+]?\d+', Number.Integer), + + # hex integer + (r'#x[0-9a-f]+', Number.Hex), + + # Macro parameters + (r'(\?' + valid_name + ')(:)' + r'(token|name|variable|expression|body|case-body|\*)', + bygroups(Name.Tag, Operator, Name.Builtin)), + (r'(\?)(:)(token|name|variable|expression|body|case-body|\*)', + bygroups(Name.Tag, Operator, Name.Builtin)), + (r'\?' + valid_name, Name.Tag), + + # Punctuation + (r'(=>|::|#\(|#\[|##|\?\?|\?=|\?|[(){}\[\],.;])', Punctuation), + + # Most operators are picked up as names and then re-flagged. + # This one isn't valid in a name though, so we pick it up now. + (r':=', Operator), + + # Pick up #t / #f before we match other stuff with #. + (r'#[tf]', Literal), + + # #"foo" style keywords + (r'#"', String.Symbol, 'keyword'), + + # #rest, #key, #all-keys, etc. + (r'#[a-z0-9-]+', Keyword), + + # required-init-keyword: style keywords. + (valid_name + ':', Keyword), + + # class names + ('<' + valid_name + '>', Name.Class), + + # define variable forms. + (r'\*' + valid_name + r'\*', Name.Variable.Global), + + # define constant forms. + (r'\$' + valid_name, Name.Constant), + + # everything else. We re-flag some of these in the method above. + (valid_name, Name), + ], + 'comment': [ + (r'[^*/]+', Comment.Multiline), + (r'/\*', Comment.Multiline, '#push'), + (r'\*/', Comment.Multiline, '#pop'), + (r'[*/]', Comment.Multiline) + ], + 'keyword': [ + (r'"', String.Symbol, '#pop'), + (r'[^\\"]+', String.Symbol), # all other characters + ], + 'string': [ + (r'"', String, '#pop'), + (r'\\([\\abfnrtv"\']|x[a-f0-9]{2,4}|[0-7]{1,3})', String.Escape), + (r'[^\\"\n]+', String), # all other characters + (r'\\\n', String), # line continuation + (r'\\', String), # stray backslash + ] + } + + +class DylanLidLexer(RegexLexer): + """ + For Dylan LID (Library Interchange Definition) files. + """ + + name = 'DylanLID' + aliases = ['dylan-lid', 'lid'] + filenames = ['*.lid', '*.hdp'] + mimetypes = ['text/x-dylan-lid'] + url = 'http://www.opendylan.org/' + version_added = '1.6' + flags = re.IGNORECASE + + tokens = { + 'root': [ + # Whitespace + (r'\s+', Whitespace), + + # single line comment + (r'(//.*?)(\n)', bygroups(Comment.Single, Whitespace)), + + # lid header + (r'(.*?)(:)([ \t]*)(.*(?:\n[ \t].+)*)', + bygroups(Name.Attribute, Operator, Whitespace, String)), + ] + } + + +class DylanConsoleLexer(Lexer): + """ + For Dylan interactive console output. + + This is based on a copy of the ``RubyConsoleLexer``. + """ + name = 'Dylan session' + aliases = ['dylan-console', 'dylan-repl'] + filenames = ['*.dylan-console'] + mimetypes = ['text/x-dylan-console'] + url = 'http://www.opendylan.org/' + version_added = '1.6' + _example = 'dylan-console/console.dylan-console' + + _prompt_re = re.compile(r'\?| ') + + def get_tokens_unprocessed(self, text): + dylexer = DylanLexer(**self.options) + + curcode = '' + insertions = [] + for match in line_re.finditer(text): + line = match.group() + m = self._prompt_re.match(line) + if m is not None: + end = m.end() + insertions.append((len(curcode), + [(0, Generic.Prompt, line[:end])])) + curcode += line[end:] + else: + if curcode: + yield from do_insertions(insertions, + dylexer.get_tokens_unprocessed(curcode)) + curcode = '' + insertions = [] + yield match.start(), Generic.Output, line + if curcode: + yield from do_insertions(insertions, + dylexer.get_tokens_unprocessed(curcode)) diff --git a/pygments/lexers/ecl.py b/pygments/lexers/ecl.py new file mode 100644 index 0000000000000000000000000000000000000000..f1ee3bf176df8c912b8d2bd61c507051e2ce683b --- /dev/null +++ b/pygments/lexers/ecl.py @@ -0,0 +1,144 @@ +""" + pygments.lexers.ecl + ~~~~~~~~~~~~~~~~~~~ + + Lexers for the ECL language. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +import re + +from pygments.lexer import RegexLexer, include, bygroups, words +from pygments.token import Comment, Operator, Keyword, Name, String, \ + Number, Punctuation, Whitespace + +__all__ = ['ECLLexer'] + + +class ECLLexer(RegexLexer): + """ + Lexer for the declarative big-data ECL language. + """ + + name = 'ECL' + url = 'https://hpccsystems.com/training/documentation/ecl-language-reference/html' + aliases = ['ecl'] + filenames = ['*.ecl'] + mimetypes = ['application/x-ecl'] + version_added = '1.5' + + flags = re.IGNORECASE | re.MULTILINE + + tokens = { + 'root': [ + include('whitespace'), + include('statements'), + ], + 'whitespace': [ + (r'\s+', Whitespace), + (r'\/\/.*', Comment.Single), + (r'/(\\\n)?\*(.|\n)*?\*(\\\n)?/', Comment.Multiline), + ], + 'statements': [ + include('types'), + include('keywords'), + include('functions'), + include('hash'), + (r'"', String, 'string'), + (r'\'', String, 'string'), + (r'(\d+\.\d*|\.\d+|\d+)e[+-]?\d+[lu]*', Number.Float), + (r'(\d+\.\d*|\.\d+|\d+f)f?', Number.Float), + (r'0x[0-9a-f]+[lu]*', Number.Hex), + (r'0[0-7]+[lu]*', Number.Oct), + (r'\d+[lu]*', Number.Integer), + (r'[~!%^&*+=|?:<>/-]+', Operator), + (r'[{}()\[\],.;]', Punctuation), + (r'[a-z_]\w*', Name), + ], + 'hash': [ + (r'^#.*$', Comment.Preproc), + ], + 'types': [ + (r'(RECORD|END)\D', Keyword.Declaration), + (r'((?:ASCII|BIG_ENDIAN|BOOLEAN|DATA|DECIMAL|EBCDIC|INTEGER|PATTERN|' + r'QSTRING|REAL|RECORD|RULE|SET OF|STRING|TOKEN|UDECIMAL|UNICODE|' + r'UNSIGNED|VARSTRING|VARUNICODE)\d*)(\s+)', + bygroups(Keyword.Type, Whitespace)), + ], + 'keywords': [ + (words(( + 'APPLY', 'ASSERT', 'BUILD', 'BUILDINDEX', 'EVALUATE', 'FAIL', + 'KEYDIFF', 'KEYPATCH', 'LOADXML', 'NOTHOR', 'NOTIFY', 'OUTPUT', + 'PARALLEL', 'SEQUENTIAL', 'SOAPCALL', 'CHECKPOINT', 'DEPRECATED', + 'FAILCODE', 'FAILMESSAGE', 'FAILURE', 'GLOBAL', 'INDEPENDENT', + 'ONWARNING', 'PERSIST', 'PRIORITY', 'RECOVERY', 'STORED', 'SUCCESS', + 'WAIT', 'WHEN'), suffix=r'\b'), + Keyword.Reserved), + # These are classed differently, check later + (words(( + 'ALL', 'AND', 'ANY', 'AS', 'ATMOST', 'BEFORE', 'BEGINC++', 'BEST', + 'BETWEEN', 'CASE', 'CONST', 'COUNTER', 'CSV', 'DESCEND', 'ENCRYPT', + 'ENDC++', 'ENDMACRO', 'EXCEPT', 'EXCLUSIVE', 'EXPIRE', 'EXPORT', + 'EXTEND', 'FALSE', 'FEW', 'FIRST', 'FLAT', 'FULL', 'FUNCTION', + 'GROUP', 'HEADER', 'HEADING', 'HOLE', 'IFBLOCK', 'IMPORT', 'IN', + 'JOINED', 'KEEP', 'KEYED', 'LAST', 'LEFT', 'LIMIT', 'LOAD', 'LOCAL', + 'LOCALE', 'LOOKUP', 'MACRO', 'MANY', 'MAXCOUNT', 'MAXLENGTH', + 'MIN SKEW', 'MODULE', 'INTERFACE', 'NAMED', 'NOCASE', 'NOROOT', + 'NOSCAN', 'NOSORT', 'NOT', 'OF', 'ONLY', 'OPT', 'OR', 'OUTER', + 'OVERWRITE', 'PACKED', 'PARTITION', 'PENALTY', 'PHYSICALLENGTH', + 'PIPE', 'QUOTE', 'RELATIONSHIP', 'REPEAT', 'RETURN', 'RIGHT', + 'SCAN', 'SELF', 'SEPARATOR', 'SERVICE', 'SHARED', 'SKEW', 'SKIP', + 'SQL', 'STORE', 'TERMINATOR', 'THOR', 'THRESHOLD', 'TOKEN', + 'TRANSFORM', 'TRIM', 'TRUE', 'TYPE', 'UNICODEORDER', 'UNSORTED', + 'VALIDATE', 'VIRTUAL', 'WHOLE', 'WILD', 'WITHIN', 'XML', 'XPATH', + '__COMPRESSED__'), suffix=r'\b'), + Keyword.Reserved), + ], + 'functions': [ + (words(( + 'ABS', 'ACOS', 'ALLNODES', 'ASCII', 'ASIN', 'ASSTRING', 'ATAN', + 'ATAN2', 'AVE', 'CASE', 'CHOOSE', 'CHOOSEN', 'CHOOSESETS', + 'CLUSTERSIZE', 'COMBINE', 'CORRELATION', 'COS', 'COSH', 'COUNT', + 'COVARIANCE', 'CRON', 'DATASET', 'DEDUP', 'DEFINE', 'DENORMALIZE', + 'DISTRIBUTE', 'DISTRIBUTED', 'DISTRIBUTION', 'EBCDIC', 'ENTH', + 'ERROR', 'EVALUATE', 'EVENT', 'EVENTEXTRA', 'EVENTNAME', 'EXISTS', + 'EXP', 'FAILCODE', 'FAILMESSAGE', 'FETCH', 'FROMUNICODE', + 'GETISVALID', 'GLOBAL', 'GRAPH', 'GROUP', 'HASH', 'HASH32', + 'HASH64', 'HASHCRC', 'HASHMD5', 'HAVING', 'IF', 'INDEX', + 'INTFORMAT', 'ISVALID', 'ITERATE', 'JOIN', 'KEYUNICODE', 'LENGTH', + 'LIBRARY', 'LIMIT', 'LN', 'LOCAL', 'LOG', 'LOOP', 'MAP', 'MATCHED', + 'MATCHLENGTH', 'MATCHPOSITION', 'MATCHTEXT', 'MATCHUNICODE', 'MAX', + 'MERGE', 'MERGEJOIN', 'MIN', 'NOLOCAL', 'NONEMPTY', 'NORMALIZE', + 'PARSE', 'PIPE', 'POWER', 'PRELOAD', 'PROCESS', 'PROJECT', 'PULL', + 'RANDOM', 'RANGE', 'RANK', 'RANKED', 'REALFORMAT', 'RECORDOF', + 'REGEXFIND', 'REGEXREPLACE', 'REGROUP', 'REJECTED', 'ROLLUP', + 'ROUND', 'ROUNDUP', 'ROW', 'ROWDIFF', 'SAMPLE', 'SET', 'SIN', + 'SINH', 'SIZEOF', 'SOAPCALL', 'SORT', 'SORTED', 'SQRT', 'STEPPED', + 'STORED', 'SUM', 'TABLE', 'TAN', 'TANH', 'THISNODE', 'TOPN', + 'TOUNICODE', 'TRANSFER', 'TRIM', 'TRUNCATE', 'TYPEOF', 'UNGROUP', + 'UNICODEORDER', 'VARIANCE', 'WHICH', 'WORKUNIT', 'XMLDECODE', + 'XMLENCODE', 'XMLTEXT', 'XMLUNICODE'), suffix=r'\b'), + Name.Function), + ], + 'string': [ + (r'"', String, '#pop'), + (r'\'', String, '#pop'), + (r'[^"\']+', String), + ], + } + + def analyse_text(text): + """This is very difficult to guess relative to other business languages. + -> in conjunction with BEGIN/END seems relatively rare though.""" + result = 0 + + if '->' in text: + result += 0.01 + if 'BEGIN' in text: + result += 0.01 + if 'END' in text: + result += 0.01 + + return result diff --git a/pygments/lexers/eiffel.py b/pygments/lexers/eiffel.py new file mode 100644 index 0000000000000000000000000000000000000000..691c718736d6ab8b00ff3974cc3851b6250318fb --- /dev/null +++ b/pygments/lexers/eiffel.py @@ -0,0 +1,68 @@ +""" + pygments.lexers.eiffel + ~~~~~~~~~~~~~~~~~~~~~~ + + Lexer for the Eiffel language. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +from pygments.lexer import RegexLexer, include, words, bygroups +from pygments.token import Comment, Operator, Keyword, Name, String, Number, \ + Punctuation, Whitespace + +__all__ = ['EiffelLexer'] + + +class EiffelLexer(RegexLexer): + """ + For Eiffel source code. + """ + name = 'Eiffel' + url = 'https://www.eiffel.com' + aliases = ['eiffel'] + filenames = ['*.e'] + mimetypes = ['text/x-eiffel'] + version_added = '2.0' + + tokens = { + 'root': [ + (r'[^\S\n]+', Whitespace), + (r'--.*?$', Comment.Single), + (r'[^\S\n]+', Whitespace), + # Please note that keyword and operator are case insensitive. + (r'(?i)(true|false|void|current|result|precursor)\b', Keyword.Constant), + (r'(?i)(not|xor|implies|or)\b', Operator.Word), + (r'(?i)(and)(?:(\s+)(then))?\b', + bygroups(Operator.Word, Whitespace, Operator.Word)), + (r'(?i)(or)(?:(\s+)(else))?\b', + bygroups(Operator.Word, Whitespace, Operator.Word)), + (words(( + 'across', 'agent', 'alias', 'all', 'as', 'assign', 'attached', + 'attribute', 'check', 'class', 'convert', 'create', 'debug', + 'deferred', 'detachable', 'do', 'else', 'elseif', 'end', 'ensure', + 'expanded', 'export', 'external', 'feature', 'from', 'frozen', 'if', + 'inherit', 'inspect', 'invariant', 'like', 'local', 'loop', 'none', + 'note', 'obsolete', 'old', 'once', 'only', 'redefine', 'rename', + 'require', 'rescue', 'retry', 'select', 'separate', 'then', + 'undefine', 'until', 'variant', 'when'), prefix=r'(?i)\b', suffix=r'\b'), + Keyword.Reserved), + (r'"\[([^\]%]|%(.|\n)|\][^"])*?\]"', String), + (r'"([^"%\n]|%.)*?"', String), + include('numbers'), + (r"'([^'%]|%'|%%)'", String.Char), + (r"(//|\\\\|>=|<=|:=|/=|~|/~|[\\?!#%&@|+/\-=>*$<^\[\]])", Operator), + (r"([{}():;,.])", Punctuation), + (r'([a-z]\w*)|([A-Z][A-Z0-9_]*[a-z]\w*)', Name), + (r'([A-Z][A-Z0-9_]*)', Name.Class), + (r'\n+', Whitespace), + ], + 'numbers': [ + (r'0[xX][a-fA-F0-9]+', Number.Hex), + (r'0[bB][01]+', Number.Bin), + (r'0[cC][0-7]+', Number.Oct), + (r'([0-9]+\.[0-9]*)|([0-9]*\.[0-9]+)', Number.Float), + (r'[0-9]+', Number.Integer), + ], + } diff --git a/pygments/lexers/elm.py b/pygments/lexers/elm.py new file mode 100644 index 0000000000000000000000000000000000000000..cebbbf00295ec70f092f54581567b5bbcf17a5fb --- /dev/null +++ b/pygments/lexers/elm.py @@ -0,0 +1,123 @@ +""" + pygments.lexers.elm + ~~~~~~~~~~~~~~~~~~~ + + Lexer for the Elm programming language. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +from pygments.lexer import RegexLexer, words, include, bygroups +from pygments.token import Comment, Keyword, Name, Number, Punctuation, \ + String, Whitespace + +__all__ = ['ElmLexer'] + + +class ElmLexer(RegexLexer): + """ + For Elm source code. + """ + + name = 'Elm' + url = 'https://elm-lang.org/' + aliases = ['elm'] + filenames = ['*.elm'] + mimetypes = ['text/x-elm'] + version_added = '2.1' + + validName = r'[a-z_][a-zA-Z0-9_\']*' + + specialName = r'^main ' + + builtinOps = ( + '~', '||', '|>', '|', '`', '^', '\\', '\'', '>>', '>=', '>', '==', + '=', '<~', '<|', '<=', '<<', '<-', '<', '::', ':', '/=', '//', '/', + '..', '.', '->', '-', '++', '+', '*', '&&', '%', + ) + + reservedWords = words(( + 'alias', 'as', 'case', 'else', 'if', 'import', 'in', + 'let', 'module', 'of', 'port', 'then', 'type', 'where', + ), suffix=r'\b') + + tokens = { + 'root': [ + + # Comments + (r'\{-', Comment.Multiline, 'comment'), + (r'--.*', Comment.Single), + + # Whitespace + (r'\s+', Whitespace), + + # Strings + (r'"', String, 'doublequote'), + + # Modules + (r'^(\s*)(module)(\s*)', bygroups(Whitespace, Keyword.Namespace, + Whitespace), 'imports'), + + # Imports + (r'^(\s*)(import)(\s*)', bygroups(Whitespace, Keyword.Namespace, + Whitespace), 'imports'), + + # Shaders + (r'\[glsl\|.*', Name.Entity, 'shader'), + + # Keywords + (reservedWords, Keyword.Reserved), + + # Types + (r'[A-Z][a-zA-Z0-9_]*', Keyword.Type), + + # Main + (specialName, Keyword.Reserved), + + # Prefix Operators + (words((builtinOps), prefix=r'\(', suffix=r'\)'), Name.Function), + + # Infix Operators + (words(builtinOps), Name.Function), + + # Numbers + include('numbers'), + + # Variable Names + (validName, Name.Variable), + + # Parens + (r'[,()\[\]{}]', Punctuation), + + ], + + 'comment': [ + (r'-(?!\})', Comment.Multiline), + (r'\{-', Comment.Multiline, 'comment'), + (r'[^-}]', Comment.Multiline), + (r'-\}', Comment.Multiline, '#pop'), + ], + + 'doublequote': [ + (r'\\u[0-9a-fA-F]{4}', String.Escape), + (r'\\[nrfvb\\"]', String.Escape), + (r'[^"]', String), + (r'"', String, '#pop'), + ], + + 'imports': [ + (r'\w+(\.\w+)*', Name.Class, '#pop'), + ], + + 'numbers': [ + (r'_?\d+\.(?=\d+)', Number.Float), + (r'_?\d+', Number.Integer), + ], + + 'shader': [ + (r'\|(?!\])', Name.Entity), + (r'\|\]', Name.Entity, '#pop'), + (r'(.*)(\n)', bygroups(Name.Entity, Whitespace)), + ], + } diff --git a/pygments/lexers/elpi.py b/pygments/lexers/elpi.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4b1e5a0ed446a613413c01661d3016dc9c9412 --- /dev/null +++ b/pygments/lexers/elpi.py @@ -0,0 +1,175 @@ +""" + pygments.lexers.elpi + ~~~~~~~~~~~~~~~~~~~~ + + Lexer for the `Elpi `_ programming language. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +from pygments.lexer import RegexLexer, bygroups, include, using +from pygments.token import Text, Comment, Operator, Keyword, Name, String, \ + Number, Punctuation + +__all__ = ['ElpiLexer'] + +from pygments.lexers.theorem import CoqLexer + +class ElpiLexer(RegexLexer): + """ + Lexer for the Elpi programming language. + """ + + name = 'Elpi' + url = 'http://github.com/LPCIC/elpi' + aliases = ['elpi'] + filenames = ['*.elpi'] + mimetypes = ['text/x-elpi'] + version_added = '2.11' + + lcase_re = r"[a-z]" + ucase_re = r"[A-Z]" + digit_re = r"[0-9]" + schar2_re = r"([+*^?/<>`'@#~=&!])" + schar_re = rf"({schar2_re}|-|\$|_)" + idchar_re = rf"({lcase_re}|{ucase_re}|{digit_re}|{schar_re})" + idcharstarns_re = rf"({idchar_re}*(\.({lcase_re}|{ucase_re}){idchar_re}*)*)" + symbchar_re = rf"({lcase_re}|{ucase_re}|{digit_re}|{schar_re}|:)" + constant_re = rf"({ucase_re}{idchar_re}*|{lcase_re}{idcharstarns_re}|{schar2_re}{symbchar_re}*|_{idchar_re}+)" + symbol_re = r"(,|<=>|->|:-|;|\?-|->|&|=>|\bas\b|\buvar\b|<|=<|=|==|>=|>|\bi<|\bi=<|\bi>=|\bi>|\bis\b|\br<|\br=<|\br>=|\br>|\bs<|\bs=<|\bs>=|\bs>|@|::|\[\]|`->|`:|`:=|\^|-|\+|\bi-|\bi\+|r-|r\+|/|\*|\bdiv\b|\bi\*|\bmod\b|\br\*|~|\bi~|\br~)" + escape_re = rf"\(({constant_re}|{symbol_re})\)" + const_sym_re = rf"({constant_re}|{symbol_re}|{escape_re})" + + tokens = { + 'root': [ + include('elpi') + ], + + 'elpi': [ + include('_elpi-comment'), + + (r"(:before|:after|:if|:name)(\s*)(\")", + bygroups(Keyword.Mode, Text.Whitespace, String.Double), + 'elpi-string'), + (r"(:index)(\s*)(\()", bygroups(Keyword.Mode, Text.Whitespace, Punctuation), + 'elpi-indexing-expr'), + (rf"\b(external pred|pred)(\s+)({const_sym_re})", + bygroups(Keyword.Declaration, Text.Whitespace, Name.Function), + 'elpi-pred-item'), + (rf"\b(external type|type)(\s+)(({const_sym_re}(,\s*)?)+)", + bygroups(Keyword.Declaration, Text.Whitespace, Name.Function), + 'elpi-type'), + (rf"\b(kind)(\s+)(({const_sym_re}|,)+)", + bygroups(Keyword.Declaration, Text.Whitespace, Name.Function), + 'elpi-type'), + (rf"\b(typeabbrev)(\s+)({const_sym_re})", + bygroups(Keyword.Declaration, Text.Whitespace, Name.Function), + 'elpi-type'), + (r"\b(typeabbrev)(\s+)(\([^)]+\))", + bygroups(Keyword.Declaration, Text.Whitespace, Name.Function), + 'elpi-type'), + (r"\b(accumulate)(\s+)(\")", + bygroups(Keyword.Declaration, Text.Whitespace, String.Double), + 'elpi-string'), + (rf"\b(accumulate|namespace|local)(\s+)({constant_re})", + bygroups(Keyword.Declaration, Text.Whitespace, Text)), + (rf"\b(shorten)(\s+)({constant_re}\.)", + bygroups(Keyword.Declaration, Text.Whitespace, Text)), + (r"\b(pi|sigma)(\s+)([a-zA-Z][A-Za-z0-9_ ]*)(\\)", + bygroups(Keyword.Declaration, Text.Whitespace, Name.Variable, Text)), + (rf"\b(constraint)(\s+)(({const_sym_re}(\s+)?)+)", + bygroups(Keyword.Declaration, Text.Whitespace, Name.Function), + 'elpi-chr-rule-start'), + + (rf"(?=[A-Z_]){constant_re}", Name.Variable), + (rf"(?=[a-z_])({constant_re}|_)\\", Name.Variable), + (r"_", Name.Variable), + (rf"({symbol_re}|!|=>|;)", Keyword.Declaration), + (constant_re, Text), + (r"\[|\]|\||=>", Keyword.Declaration), + (r'"', String.Double, 'elpi-string'), + (r'`', String.Double, 'elpi-btick'), + (r'\'', String.Double, 'elpi-tick'), + (r'\{\{', Punctuation, 'elpi-quote'), + (r'\{[^\{]', Text, 'elpi-spill'), + (r"\(", Punctuation, 'elpi-in-parens'), + (r'\d[\d_]*', Number.Integer), + (r'-?\d[\d_]*(.[\d_]*)?([eE][+\-]?\d[\d_]*)', Number.Float), + (r"[\+\*\-/\^\.]", Operator), + ], + '_elpi-comment': [ + (r'%[^\n]*\n', Comment), + (r'/(?:\\\n)?[*](?:[^*]|[*](?!(?:\\\n)?/))*[*](?:\\\n)?/', Comment), + (r"\s+", Text.Whitespace), + ], + 'elpi-indexing-expr':[ + (r'[0-9 _]+', Number.Integer), + (r'\)', Punctuation, '#pop'), + ], + 'elpi-type': [ + (r"(ctype\s+)(\")", bygroups(Keyword.Type, String.Double), 'elpi-string'), + (r'->', Keyword.Type), + (constant_re, Keyword.Type), + (r"\(|\)", Keyword.Type), + (r"\.", Text, '#pop'), + include('_elpi-comment'), + ], + 'elpi-chr-rule-start': [ + (r"\{", Punctuation, 'elpi-chr-rule'), + include('_elpi-comment'), + ], + 'elpi-chr-rule': [ + (r"\brule\b", Keyword.Declaration), + (r"\\", Keyword.Declaration), + (r"\}", Punctuation, '#pop:2'), + include('elpi'), + ], + 'elpi-pred-item': [ + (r"[io]:", Keyword.Mode, 'elpi-ctype'), + (r"\.", Text, '#pop'), + include('_elpi-comment'), + ], + 'elpi-ctype': [ + (r"(ctype\s+)(\")", bygroups(Keyword.Type, String.Double), 'elpi-string'), + (r'->', Keyword.Type), + (constant_re, Keyword.Type), + (r"\(|\)", Keyword.Type), + (r",", Text, '#pop'), + (r"\.", Text, '#pop:2'), + include('_elpi-comment'), + ], + 'elpi-btick': [ + (r'[^` ]+', String.Double), + (r'`', String.Double, '#pop'), + ], + 'elpi-tick': [ + (r'[^\' ]+', String.Double), + (r'\'', String.Double, '#pop'), + ], + 'elpi-string': [ + (r'[^\"]+', String.Double), + (r'"', String.Double, '#pop'), + ], + 'elpi-quote': [ + (r'\}\}', Punctuation, '#pop'), + (r"\s+", Text.Whitespace), + (r"(lp:)(\{\{)", bygroups(Number, Punctuation), 'elpi-quote-exit'), + (rf"(lp:)((?=[A-Z_]){constant_re})", bygroups(Number, Name.Variable)), + (r"((?!lp:|\}\}).)+", using(CoqLexer)), + ], + 'elpi-quote-exit': [ + include('elpi'), + (r'\}\}', Punctuation, '#pop'), + ], + 'elpi-spill': [ + (r'\{[^\{]', Text, '#push'), + (r'\}[^\}]', Text, '#pop'), + include('elpi'), + ], + 'elpi-in-parens': [ + (r"\(", Punctuation, '#push'), + include('elpi'), + (r"\)", Punctuation, '#pop'), + ], + } diff --git a/pygments/lexers/email.py b/pygments/lexers/email.py new file mode 100644 index 0000000000000000000000000000000000000000..ea071c80026f3399cef79d31b17b2e9587b4e038 --- /dev/null +++ b/pygments/lexers/email.py @@ -0,0 +1,132 @@ +""" + pygments.lexers.email + ~~~~~~~~~~~~~~~~~~~~~ + + Lexer for the raw E-mail. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +from pygments.lexer import RegexLexer, DelegatingLexer, bygroups +from pygments.lexers.mime import MIMELexer +from pygments.token import Text, Keyword, Name, String, Number, Comment +from pygments.util import get_bool_opt + +__all__ = ["EmailLexer"] + + +class EmailHeaderLexer(RegexLexer): + """ + Sub-lexer for raw E-mail. This lexer only process header part of e-mail. + + .. versionadded:: 2.5 + """ + + def __init__(self, **options): + super().__init__(**options) + self.highlight_x = get_bool_opt(options, "highlight-X-header", False) + + def get_x_header_tokens(self, match): + if self.highlight_x: + # field + yield match.start(1), Name.Tag, match.group(1) + + # content + default_actions = self.get_tokens_unprocessed( + match.group(2), stack=("root", "header")) + yield from default_actions + else: + # lowlight + yield match.start(1), Comment.Special, match.group(1) + yield match.start(2), Comment.Multiline, match.group(2) + + tokens = { + "root": [ + (r"^(?:[A-WYZ]|X400)[\w\-]*:", Name.Tag, "header"), + (r"^(X-(?:\w[\w\-]*:))([\s\S]*?\n)(?![ \t])", get_x_header_tokens), + ], + "header": [ + # folding + (r"\n[ \t]", Text.Whitespace), + (r"\n(?![ \t])", Text.Whitespace, "#pop"), + + # keywords + (r"\bE?SMTPS?\b", Keyword), + (r"\b(?:HE|EH)LO\b", Keyword), + + # mailbox + (r"[\w\.\-\+=]+@[\w\.\-]+", Name.Label), + (r"<[\w\.\-\+=]+@[\w\.\-]+>", Name.Label), + + # domain + (r"\b(\w[\w\.-]*\.[\w\.-]*\w[a-zA-Z]+)\b", Name.Function), + + # IPv4 + (r"(?<=\b)(?:(?:25[0-5]|2[0-4][0-9]|1?[0-9][0-9]?)\.){3}(?:25[0" + r"-5]|2[0-4][0-9]|1?[0-9][0-9]?)(?=\b)", + Number.Integer), + + # IPv6 + (r"(?<=\b)([0-9a-fA-F]{1,4}:){1,7}:(?!\b)", Number.Hex), + (r"(?<=\b):((:[0-9a-fA-F]{1,4}){1,7}|:)(?=\b)", Number.Hex), + (r"(?<=\b)([0-9a-fA-F]{1,4}:){7,7}[0-9a-fA-F]{1,4}(?=\b)", Number.Hex), + (r"(?<=\b)([0-9a-fA-F]{1,4}:){1,6}:[0-9a-fA-F]{1,4}(?=\b)", Number.Hex), + (r"(?<=\b)[0-9a-fA-F]{1,4}:((:[0-9a-fA-F]{1,4}){1,6})(?=\b)", Number.Hex), + (r"(?<=\b)fe80:(:[0-9a-fA-F]{0,4}){0,4}%[0-9a-zA-Z]{1,}(?=\b)", Number.Hex), + (r"(?<=\b)([0-9a-fA-F]{1,4}:){1,5}(:[0-9a-fA-F]{1,4}){1,2}(?=\b)", Number.Hex), + (r"(?<=\b)([0-9a-fA-F]{1,4}:){1,4}(:[0-9a-fA-F]{1,4}){1,3}(?=\b)", + Number.Hex), + (r"(?<=\b)([0-9a-fA-F]{1,4}:){1,3}(:[0-9a-fA-F]{1,4}){1,4}(?=\b)", + Number.Hex), + (r"(?<=\b)([0-9a-fA-F]{1,4}:){1,2}(:[0-9a-fA-F]{1,4}){1,5}(?=\b)", + Number.Hex), + (r"(?<=\b)::(ffff(:0{1,4}){0,1}:){0,1}((25[0-5]|(2[0-4]|1{0,1}" + r"[0-9]){0,1}[0-9])\.){3,3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}" + r"[0-9])(?=\b)", + Number.Hex), + (r"(?<=\b)([0-9a-fA-F]{1,4}:){1,4}:((25[0-5]|(2[0-4]|1{0,1}[0-9])" + r"{0,1}[0-9])\.){3,3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])(?=\b)", + Number.Hex), + + # Date time + (r"(?:(Sun|Mon|Tue|Wed|Thu|Fri|Sat),\s+)?(0[1-9]|[1-2]?[0-9]|3[" + r"01])\s+(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s+(" + r"19[0-9]{2}|[2-9][0-9]{3})\s+(2[0-3]|[0-1][0-9]):([0-5][0-9])" + r"(?::(60|[0-5][0-9]))?(?:\.\d{1,5})?\s+([-\+][0-9]{2}[0-5][0-" + r"9]|\(?(?:UTC?|GMT|(?:E|C|M|P)(?:ST|ET|DT)|[A-IK-Z])\)?)", + Name.Decorator), + + # RFC-2047 encoded string + (r"(=\?)([\w-]+)(\?)([BbQq])(\?)([\[\w!\"#$%&\'()*+,-./:;<=>@[\\" + r"\]^_`{|}~]+)(\?=)", + bygroups(String.Affix, Name.Constant, String.Affix, Keyword.Constant, + String.Affix, Number.Hex, String.Affix)), + + # others + (r'[\s]+', Text.Whitespace), + (r'[\S]', Text), + ], + } + + +class EmailLexer(DelegatingLexer): + """ + Lexer for raw E-mail. + + Additional options accepted: + + `highlight-X-header` + Highlight the fields of ``X-`` user-defined email header. (default: + ``False``). + """ + + name = "E-mail" + aliases = ["email", "eml"] + filenames = ["*.eml"] + mimetypes = ["message/rfc822"] + url = "https://en.wikipedia.org/wiki/Email#Message_format" + version_added = '2.5' + + def __init__(self, **options): + super().__init__(EmailHeaderLexer, MIMELexer, Comment, **options) diff --git a/pygments/lexers/erlang.py b/pygments/lexers/erlang.py new file mode 100644 index 0000000000000000000000000000000000000000..120f5049fdc3942173710950c3fbebf7a36ef6cd --- /dev/null +++ b/pygments/lexers/erlang.py @@ -0,0 +1,526 @@ +""" + pygments.lexers.erlang + ~~~~~~~~~~~~~~~~~~~~~~ + + Lexers for Erlang. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +import re + +from pygments.lexer import Lexer, RegexLexer, bygroups, words, do_insertions, \ + include, default, line_re +from pygments.token import Comment, Operator, Keyword, Name, String, \ + Number, Punctuation, Generic, Whitespace + +__all__ = ['ErlangLexer', 'ErlangShellLexer', 'ElixirConsoleLexer', + 'ElixirLexer'] + + +class ErlangLexer(RegexLexer): + """ + For the Erlang functional programming language. + """ + + name = 'Erlang' + url = 'https://www.erlang.org/' + aliases = ['erlang'] + filenames = ['*.erl', '*.hrl', '*.es', '*.escript'] + mimetypes = ['text/x-erlang'] + version_added = '0.9' + + keywords = ( + 'after', 'begin', 'case', 'catch', 'cond', 'end', 'fun', 'if', + 'let', 'of', 'query', 'receive', 'try', 'when', + ) + + builtins = ( # See erlang(3) man page + 'abs', 'append_element', 'apply', 'atom_to_list', 'binary_to_list', + 'bitstring_to_list', 'binary_to_term', 'bit_size', 'bump_reductions', + 'byte_size', 'cancel_timer', 'check_process_code', 'delete_module', + 'demonitor', 'disconnect_node', 'display', 'element', 'erase', 'exit', + 'float', 'float_to_list', 'fun_info', 'fun_to_list', + 'function_exported', 'garbage_collect', 'get', 'get_keys', + 'group_leader', 'hash', 'hd', 'integer_to_list', 'iolist_to_binary', + 'iolist_size', 'is_atom', 'is_binary', 'is_bitstring', 'is_boolean', + 'is_builtin', 'is_float', 'is_function', 'is_integer', 'is_list', + 'is_number', 'is_pid', 'is_port', 'is_process_alive', 'is_record', + 'is_reference', 'is_tuple', 'length', 'link', 'list_to_atom', + 'list_to_binary', 'list_to_bitstring', 'list_to_existing_atom', + 'list_to_float', 'list_to_integer', 'list_to_pid', 'list_to_tuple', + 'load_module', 'localtime_to_universaltime', 'make_tuple', 'md5', + 'md5_final', 'md5_update', 'memory', 'module_loaded', 'monitor', + 'monitor_node', 'node', 'nodes', 'open_port', 'phash', 'phash2', + 'pid_to_list', 'port_close', 'port_command', 'port_connect', + 'port_control', 'port_call', 'port_info', 'port_to_list', + 'process_display', 'process_flag', 'process_info', 'purge_module', + 'put', 'read_timer', 'ref_to_list', 'register', 'resume_process', + 'round', 'send', 'send_after', 'send_nosuspend', 'set_cookie', + 'setelement', 'size', 'spawn', 'spawn_link', 'spawn_monitor', + 'spawn_opt', 'split_binary', 'start_timer', 'statistics', + 'suspend_process', 'system_flag', 'system_info', 'system_monitor', + 'system_profile', 'term_to_binary', 'tl', 'trace', 'trace_delivered', + 'trace_info', 'trace_pattern', 'trunc', 'tuple_size', 'tuple_to_list', + 'universaltime_to_localtime', 'unlink', 'unregister', 'whereis' + ) + + operators = r'(\+\+?|--?|\*|/|<|>|/=|=:=|=/=|=<|>=|==?|<-|!|\?)' + word_operators = ( + 'and', 'andalso', 'band', 'bnot', 'bor', 'bsl', 'bsr', 'bxor', + 'div', 'not', 'or', 'orelse', 'rem', 'xor' + ) + + atom_re = r"(?:[a-z]\w*|'[^\n']*[^\\]')" + + variable_re = r'(?:[A-Z_]\w*)' + + esc_char_re = r'[bdefnrstv\'"\\]' + esc_octal_re = r'[0-7][0-7]?[0-7]?' + esc_hex_re = r'(?:x[0-9a-fA-F]{2}|x\{[0-9a-fA-F]+\})' + esc_ctrl_re = r'\^[a-zA-Z]' + escape_re = r'(?:\\(?:'+esc_char_re+r'|'+esc_octal_re+r'|'+esc_hex_re+r'|'+esc_ctrl_re+r'))' + + macro_re = r'(?:'+variable_re+r'|'+atom_re+r')' + + base_re = r'(?:[2-9]|[12][0-9]|3[0-6])' + + tokens = { + 'root': [ + (r'\s+', Whitespace), + (r'(%.*)(\n)', bygroups(Comment, Whitespace)), + (words(keywords, suffix=r'\b'), Keyword), + (words(builtins, suffix=r'\b'), Name.Builtin), + (words(word_operators, suffix=r'\b'), Operator.Word), + (r'^-', Punctuation, 'directive'), + (operators, Operator), + (r'"', String, 'string'), + (r'<<', Name.Label), + (r'>>', Name.Label), + ('(' + atom_re + ')(:)', bygroups(Name.Namespace, Punctuation)), + ('(?:^|(?<=:))(' + atom_re + r')(\s*)(\()', + bygroups(Name.Function, Whitespace, Punctuation)), + (r'[+-]?' + base_re + r'#[0-9a-zA-Z]+', Number.Integer), + (r'[+-]?\d+', Number.Integer), + (r'[+-]?\d+.\d+', Number.Float), + (r'[]\[:_@\".{}()|;,]', Punctuation), + (variable_re, Name.Variable), + (atom_re, Name), + (r'\?'+macro_re, Name.Constant), + (r'\$(?:'+escape_re+r'|\\[ %]|[^\\])', String.Char), + (r'#'+atom_re+r'(:?\.'+atom_re+r')?', Name.Label), + + # Erlang script shebang + (r'\A#!.+\n', Comment.Hashbang), + + # EEP 43: Maps + # http://www.erlang.org/eeps/eep-0043.html + (r'#\{', Punctuation, 'map_key'), + ], + 'string': [ + (escape_re, String.Escape), + (r'"', String, '#pop'), + (r'~[0-9.*]*[~#+BPWXb-ginpswx]', String.Interpol), + (r'[^"\\~]+', String), + (r'~', String), + ], + 'directive': [ + (r'(define)(\s*)(\()('+macro_re+r')', + bygroups(Name.Entity, Whitespace, Punctuation, Name.Constant), '#pop'), + (r'(record)(\s*)(\()('+macro_re+r')', + bygroups(Name.Entity, Whitespace, Punctuation, Name.Label), '#pop'), + (atom_re, Name.Entity, '#pop'), + ], + 'map_key': [ + include('root'), + (r'=>', Punctuation, 'map_val'), + (r':=', Punctuation, 'map_val'), + (r'\}', Punctuation, '#pop'), + ], + 'map_val': [ + include('root'), + (r',', Punctuation, '#pop'), + (r'(?=\})', Punctuation, '#pop'), + ], + } + + +class ErlangShellLexer(Lexer): + """ + Shell sessions in erl (for Erlang code). + """ + name = 'Erlang erl session' + aliases = ['erl'] + filenames = ['*.erl-sh'] + mimetypes = ['text/x-erl-shellsession'] + url = 'https://www.erlang.org/' + version_added = '1.1' + + _prompt_re = re.compile(r'(?:\([\w@_.]+\))?\d+>(?=\s|\Z)') + + def get_tokens_unprocessed(self, text): + erlexer = ErlangLexer(**self.options) + + curcode = '' + insertions = [] + for match in line_re.finditer(text): + line = match.group() + m = self._prompt_re.match(line) + if m is not None: + end = m.end() + insertions.append((len(curcode), + [(0, Generic.Prompt, line[:end])])) + curcode += line[end:] + else: + if curcode: + yield from do_insertions(insertions, + erlexer.get_tokens_unprocessed(curcode)) + curcode = '' + insertions = [] + if line.startswith('*'): + yield match.start(), Generic.Traceback, line + else: + yield match.start(), Generic.Output, line + if curcode: + yield from do_insertions(insertions, + erlexer.get_tokens_unprocessed(curcode)) + + +def gen_elixir_string_rules(name, symbol, token): + states = {} + states['string_' + name] = [ + (rf'[^#{symbol}\\]+', token), + include('escapes'), + (r'\\.', token), + (rf'({symbol})', bygroups(token), "#pop"), + include('interpol') + ] + return states + + +def gen_elixir_sigstr_rules(term, term_class, token, interpol=True): + if interpol: + return [ + (rf'[^#{term_class}\\]+', token), + include('escapes'), + (r'\\.', token), + (rf'{term}[a-zA-Z]*', token, '#pop'), + include('interpol') + ] + else: + return [ + (rf'[^{term_class}\\]+', token), + (r'\\.', token), + (rf'{term}[a-zA-Z]*', token, '#pop'), + ] + + +class ElixirLexer(RegexLexer): + """ + For the Elixir language. + """ + + name = 'Elixir' + url = 'https://elixir-lang.org' + aliases = ['elixir', 'ex', 'exs'] + filenames = ['*.ex', '*.eex', '*.exs', '*.leex'] + mimetypes = ['text/x-elixir'] + version_added = '1.5' + + KEYWORD = ('fn', 'do', 'end', 'after', 'else', 'rescue', 'catch') + KEYWORD_OPERATOR = ('not', 'and', 'or', 'when', 'in') + BUILTIN = ( + 'case', 'cond', 'for', 'if', 'unless', 'try', 'receive', 'raise', + 'quote', 'unquote', 'unquote_splicing', 'throw', 'super', + ) + BUILTIN_DECLARATION = ( + 'def', 'defp', 'defmodule', 'defprotocol', 'defmacro', 'defmacrop', + 'defdelegate', 'defexception', 'defstruct', 'defimpl', 'defcallback', + ) + + BUILTIN_NAMESPACE = ('import', 'require', 'use', 'alias') + CONSTANT = ('nil', 'true', 'false') + + PSEUDO_VAR = ('_', '__MODULE__', '__DIR__', '__ENV__', '__CALLER__') + + OPERATORS3 = ( + '<<<', '>>>', '|||', '&&&', '^^^', '~~~', '===', '!==', + '~>>', '<~>', '|~>', '<|>', + ) + OPERATORS2 = ( + '==', '!=', '<=', '>=', '&&', '||', '<>', '++', '--', '|>', '=~', + '->', '<-', '|', '.', '=', '~>', '<~', + ) + OPERATORS1 = ('<', '>', '+', '-', '*', '/', '!', '^', '&') + + PUNCTUATION = ( + '\\\\', '<<', '>>', '=>', '(', ')', ':', ';', ',', '[', ']', + ) + + def get_tokens_unprocessed(self, text): + for index, token, value in RegexLexer.get_tokens_unprocessed(self, text): + if token is Name: + if value in self.KEYWORD: + yield index, Keyword, value + elif value in self.KEYWORD_OPERATOR: + yield index, Operator.Word, value + elif value in self.BUILTIN: + yield index, Keyword, value + elif value in self.BUILTIN_DECLARATION: + yield index, Keyword.Declaration, value + elif value in self.BUILTIN_NAMESPACE: + yield index, Keyword.Namespace, value + elif value in self.CONSTANT: + yield index, Name.Constant, value + elif value in self.PSEUDO_VAR: + yield index, Name.Builtin.Pseudo, value + else: + yield index, token, value + else: + yield index, token, value + + def gen_elixir_sigil_rules(): + # all valid sigil terminators (excluding heredocs) + terminators = [ + (r'\{', r'\}', '}', 'cb'), + (r'\[', r'\]', r'\]', 'sb'), + (r'\(', r'\)', ')', 'pa'), + ('<', '>', '>', 'ab'), + ('/', '/', '/', 'slas'), + (r'\|', r'\|', '|', 'pipe'), + ('"', '"', '"', 'quot'), + ("'", "'", "'", 'apos'), + ] + + # heredocs have slightly different rules + triquotes = [(r'"""', 'triquot'), (r"'''", 'triapos')] + + token = String.Other + states = {'sigils': []} + + for term, name in triquotes: + states['sigils'] += [ + (rf'(~[a-z])({term})', bygroups(token, String.Heredoc), + (name + '-end', name + '-intp')), + (rf'(~[A-Z])({term})', bygroups(token, String.Heredoc), + (name + '-end', name + '-no-intp')), + ] + + states[name + '-end'] = [ + (r'[a-zA-Z]+', token, '#pop'), + default('#pop'), + ] + states[name + '-intp'] = [ + (r'^(\s*)(' + term + ')', bygroups(Whitespace, String.Heredoc), '#pop'), + include('heredoc_interpol'), + ] + states[name + '-no-intp'] = [ + (r'^(\s*)(' + term +')', bygroups(Whitespace, String.Heredoc), '#pop'), + include('heredoc_no_interpol'), + ] + + for lterm, rterm, rterm_class, name in terminators: + states['sigils'] += [ + (r'~[a-z]' + lterm, token, name + '-intp'), + (r'~[A-Z]' + lterm, token, name + '-no-intp'), + ] + states[name + '-intp'] = \ + gen_elixir_sigstr_rules(rterm, rterm_class, token) + states[name + '-no-intp'] = \ + gen_elixir_sigstr_rules(rterm, rterm_class, token, interpol=False) + + return states + + op3_re = "|".join(re.escape(s) for s in OPERATORS3) + op2_re = "|".join(re.escape(s) for s in OPERATORS2) + op1_re = "|".join(re.escape(s) for s in OPERATORS1) + ops_re = rf'(?:{op3_re}|{op2_re}|{op1_re})' + punctuation_re = "|".join(re.escape(s) for s in PUNCTUATION) + alnum = r'\w' + name_re = rf'(?:\.\.\.|[a-z_]{alnum}*[!?]?)' + modname_re = rf'[A-Z]{alnum}*(?:\.[A-Z]{alnum}*)*' + complex_name_re = rf'(?:{name_re}|{modname_re}|{ops_re})' + special_atom_re = r'(?:\.\.\.|<<>>|%\{\}|%|\{\})' + + long_hex_char_re = r'(\\x\{)([\da-fA-F]+)(\})' + hex_char_re = r'(\\x[\da-fA-F]{1,2})' + escape_char_re = r'(\\[abdefnrstv])' + + tokens = { + 'root': [ + (r'\s+', Whitespace), + (r'#.*$', Comment.Single), + + # Various kinds of characters + (r'(\?)' + long_hex_char_re, + bygroups(String.Char, + String.Escape, Number.Hex, String.Escape)), + (r'(\?)' + hex_char_re, + bygroups(String.Char, String.Escape)), + (r'(\?)' + escape_char_re, + bygroups(String.Char, String.Escape)), + (r'\?\\?.', String.Char), + + # '::' has to go before atoms + (r':::', String.Symbol), + (r'::', Operator), + + # atoms + (r':' + special_atom_re, String.Symbol), + (r':' + complex_name_re, String.Symbol), + (r':"', String.Symbol, 'string_double_atom'), + (r":'", String.Symbol, 'string_single_atom'), + + # [keywords: ...] + (rf'({special_atom_re}|{complex_name_re})(:)(?=\s|\n)', + bygroups(String.Symbol, Punctuation)), + + # @attributes + (r'@' + name_re, Name.Attribute), + + # identifiers + (name_re, Name), + (rf'(%?)({modname_re})', bygroups(Punctuation, Name.Class)), + + # operators and punctuation + (op3_re, Operator), + (op2_re, Operator), + (punctuation_re, Punctuation), + (r'&\d', Name.Entity), # anon func arguments + (op1_re, Operator), + + # numbers + (r'0b[01]+', Number.Bin), + (r'0o[0-7]+', Number.Oct), + (r'0x[\da-fA-F]+', Number.Hex), + (r'\d(_?\d)*\.\d(_?\d)*([eE][-+]?\d(_?\d)*)?', Number.Float), + (r'\d(_?\d)*', Number.Integer), + + # strings and heredocs + (r'(""")(\s*)', bygroups(String.Heredoc, Whitespace), + 'heredoc_double'), + (r"(''')(\s*)$", bygroups(String.Heredoc, Whitespace), + 'heredoc_single'), + (r'"', String.Double, 'string_double'), + (r"'", String.Single, 'string_single'), + + include('sigils'), + + (r'%\{', Punctuation, 'map_key'), + (r'\{', Punctuation, 'tuple'), + ], + 'heredoc_double': [ + (r'^(\s*)(""")', bygroups(Whitespace, String.Heredoc), '#pop'), + include('heredoc_interpol'), + ], + 'heredoc_single': [ + (r"^\s*'''", String.Heredoc, '#pop'), + include('heredoc_interpol'), + ], + 'heredoc_interpol': [ + (r'[^#\\\n]+', String.Heredoc), + include('escapes'), + (r'\\.', String.Heredoc), + (r'\n+', String.Heredoc), + include('interpol'), + ], + 'heredoc_no_interpol': [ + (r'[^\\\n]+', String.Heredoc), + (r'\\.', String.Heredoc), + (r'\n+', Whitespace), + ], + 'escapes': [ + (long_hex_char_re, + bygroups(String.Escape, Number.Hex, String.Escape)), + (hex_char_re, String.Escape), + (escape_char_re, String.Escape), + ], + 'interpol': [ + (r'#\{', String.Interpol, 'interpol_string'), + ], + 'interpol_string': [ + (r'\}', String.Interpol, "#pop"), + include('root') + ], + 'map_key': [ + include('root'), + (r':', Punctuation, 'map_val'), + (r'=>', Punctuation, 'map_val'), + (r'\}', Punctuation, '#pop'), + ], + 'map_val': [ + include('root'), + (r',', Punctuation, '#pop'), + (r'(?=\})', Punctuation, '#pop'), + ], + 'tuple': [ + include('root'), + (r'\}', Punctuation, '#pop'), + ], + } + tokens.update(gen_elixir_string_rules('double', '"', String.Double)) + tokens.update(gen_elixir_string_rules('single', "'", String.Single)) + tokens.update(gen_elixir_string_rules('double_atom', '"', String.Symbol)) + tokens.update(gen_elixir_string_rules('single_atom', "'", String.Symbol)) + tokens.update(gen_elixir_sigil_rules()) + + +class ElixirConsoleLexer(Lexer): + """ + For Elixir interactive console (iex) output like: + + .. sourcecode:: iex + + iex> [head | tail] = [1,2,3] + [1,2,3] + iex> head + 1 + iex> tail + [2,3] + iex> [head | tail] + [1,2,3] + iex> length [head | tail] + 3 + """ + + name = 'Elixir iex session' + aliases = ['iex'] + mimetypes = ['text/x-elixir-shellsession'] + url = 'https://elixir-lang.org' + version_added = '1.5' + + _prompt_re = re.compile(r'(iex|\.{3})((?:\([\w@_.]+\))?\d+|\(\d+\))?> ') + + def get_tokens_unprocessed(self, text): + exlexer = ElixirLexer(**self.options) + + curcode = '' + in_error = False + insertions = [] + for match in line_re.finditer(text): + line = match.group() + if line.startswith('** '): + in_error = True + insertions.append((len(curcode), + [(0, Generic.Error, line[:-1])])) + curcode += line[-1:] + else: + m = self._prompt_re.match(line) + if m is not None: + in_error = False + end = m.end() + insertions.append((len(curcode), + [(0, Generic.Prompt, line[:end])])) + curcode += line[end:] + else: + if curcode: + yield from do_insertions( + insertions, exlexer.get_tokens_unprocessed(curcode)) + curcode = '' + insertions = [] + token = Generic.Error if in_error else Generic.Output + yield match.start(), token, line + if curcode: + yield from do_insertions( + insertions, exlexer.get_tokens_unprocessed(curcode)) diff --git a/pygments/lexers/esoteric.py b/pygments/lexers/esoteric.py new file mode 100644 index 0000000000000000000000000000000000000000..fe3825c356efc5b20395f8ac694087fb7d68d43f --- /dev/null +++ b/pygments/lexers/esoteric.py @@ -0,0 +1,300 @@ +""" + pygments.lexers.esoteric + ~~~~~~~~~~~~~~~~~~~~~~~~ + + Lexers for esoteric languages. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +from pygments.lexer import RegexLexer, include, words, bygroups +from pygments.token import Comment, Operator, Keyword, Name, String, Number, \ + Punctuation, Error, Whitespace + +__all__ = ['BrainfuckLexer', 'BefungeLexer', 'RedcodeLexer', 'CAmkESLexer', + 'CapDLLexer', 'AheuiLexer'] + + +class BrainfuckLexer(RegexLexer): + """ + Lexer for the esoteric BrainFuck language. + """ + + name = 'Brainfuck' + url = 'http://www.muppetlabs.com/~breadbox/bf/' + aliases = ['brainfuck', 'bf'] + filenames = ['*.bf', '*.b'] + mimetypes = ['application/x-brainfuck'] + version_added = '' + + tokens = { + 'common': [ + # use different colors for different instruction types + (r'[.,]+', Name.Tag), + (r'[+-]+', Name.Builtin), + (r'[<>]+', Name.Variable), + (r'[^.,+\-<>\[\]]+', Comment), + ], + 'root': [ + (r'\[', Keyword, 'loop'), + (r'\]', Error), + include('common'), + ], + 'loop': [ + (r'\[', Keyword, '#push'), + (r'\]', Keyword, '#pop'), + include('common'), + ] + } + + def analyse_text(text): + """It's safe to assume that a program which mostly consists of + - + and < > is brainfuck.""" + plus_minus_count = 0 + greater_less_count = 0 + + range_to_check = max(256, len(text)) + + for c in text[:range_to_check]: + if c == '+' or c == '-': + plus_minus_count += 1 + if c == '<' or c == '>': + greater_less_count += 1 + + if plus_minus_count > (0.25 * range_to_check): + return 1.0 + if greater_less_count > (0.25 * range_to_check): + return 1.0 + + result = 0 + if '[-]' in text: + result += 0.5 + + return result + + +class BefungeLexer(RegexLexer): + """ + Lexer for the esoteric Befunge language. + """ + name = 'Befunge' + url = 'http://en.wikipedia.org/wiki/Befunge' + aliases = ['befunge'] + filenames = ['*.befunge'] + mimetypes = ['application/x-befunge'] + version_added = '0.7' + + tokens = { + 'root': [ + (r'[0-9a-f]', Number), + (r'[+*/%!`-]', Operator), # Traditional math + (r'[<>^v?\[\]rxjk]', Name.Variable), # Move, imperatives + (r'[:\\$.,n]', Name.Builtin), # Stack ops, imperatives + (r'[|_mw]', Keyword), + (r'[{}]', Name.Tag), # Befunge-98 stack ops + (r'".*?"', String.Double), # Strings don't appear to allow escapes + (r'\'.', String.Single), # Single character + (r'[#;]', Comment), # Trampoline... depends on direction hit + (r'[pg&~=@iotsy]', Keyword), # Misc + (r'[()A-Z]', Comment), # Fingerprints + (r'\s+', Whitespace), # Whitespace doesn't matter + ], + } + + +class CAmkESLexer(RegexLexer): + """ + Basic lexer for the input language for the CAmkES component platform. + """ + name = 'CAmkES' + url = 'https://sel4.systems/CAmkES/' + aliases = ['camkes', 'idl4'] + filenames = ['*.camkes', '*.idl4'] + version_added = '2.1' + + tokens = { + 'root': [ + # C pre-processor directive + (r'^(\s*)(#.*)(\n)', bygroups(Whitespace, Comment.Preproc, + Whitespace)), + + # Whitespace, comments + (r'\s+', Whitespace), + (r'/\*(.|\n)*?\*/', Comment), + (r'//.*$', Comment), + + (r'[\[(){},.;\]]', Punctuation), + (r'[~!%^&*+=|?:<>/-]', Operator), + + (words(('assembly', 'attribute', 'component', 'composition', + 'configuration', 'connection', 'connector', 'consumes', + 'control', 'dataport', 'Dataport', 'Dataports', 'emits', + 'event', 'Event', 'Events', 'export', 'from', 'group', + 'hardware', 'has', 'interface', 'Interface', 'maybe', + 'procedure', 'Procedure', 'Procedures', 'provides', + 'template', 'thread', 'threads', 'to', 'uses', 'with'), + suffix=r'\b'), Keyword), + + (words(('bool', 'boolean', 'Buf', 'char', 'character', 'double', + 'float', 'in', 'inout', 'int', 'int16_6', 'int32_t', + 'int64_t', 'int8_t', 'integer', 'mutex', 'out', 'real', + 'refin', 'semaphore', 'signed', 'string', 'struct', + 'uint16_t', 'uint32_t', 'uint64_t', 'uint8_t', 'uintptr_t', + 'unsigned', 'void'), + suffix=r'\b'), Keyword.Type), + + # Recognised attributes + (r'[a-zA-Z_]\w*_(priority|domain|buffer)', Keyword.Reserved), + (words(('dma_pool', 'from_access', 'to_access'), suffix=r'\b'), + Keyword.Reserved), + + # CAmkES-level include + (r'(import)(\s+)((?:<[^>]*>|"[^"]*");)', + bygroups(Comment.Preproc, Whitespace, Comment.Preproc)), + + # C-level include + (r'(include)(\s+)((?:<[^>]*>|"[^"]*");)', + bygroups(Comment.Preproc, Whitespace, Comment.Preproc)), + + # Literals + (r'0[xX][\da-fA-F]+', Number.Hex), + (r'-?[\d]+', Number), + (r'-?[\d]+\.[\d]+', Number.Float), + (r'"[^"]*"', String), + (r'[Tt]rue|[Ff]alse', Name.Builtin), + + # Identifiers + (r'[a-zA-Z_]\w*', Name), + ], + } + + +class CapDLLexer(RegexLexer): + """ + Basic lexer for CapDL. + + The source of the primary tool that reads such specifications is available + at https://github.com/seL4/capdl/tree/master/capDL-tool. Note that this + lexer only supports a subset of the grammar. For example, identifiers can + shadow type names, but these instances are currently incorrectly + highlighted as types. Supporting this would need a stateful lexer that is + considered unnecessarily complex for now. + """ + name = 'CapDL' + url = 'https://ssrg.nicta.com.au/publications/nictaabstracts/Kuz_KLW_10.abstract.pml' + aliases = ['capdl'] + filenames = ['*.cdl'] + version_added = '2.2' + + tokens = { + 'root': [ + # C pre-processor directive + (r'^(\s*)(#.*)(\n)', + bygroups(Whitespace, Comment.Preproc, Whitespace)), + + # Whitespace, comments + (r'\s+', Whitespace), + (r'/\*(.|\n)*?\*/', Comment), + (r'(//|--).*$', Comment), + + (r'[<>\[(){},:;=\]]', Punctuation), + (r'\.\.', Punctuation), + + (words(('arch', 'arm11', 'caps', 'child_of', 'ia32', 'irq', 'maps', + 'objects'), suffix=r'\b'), Keyword), + + (words(('aep', 'asid_pool', 'cnode', 'ep', 'frame', 'io_device', + 'io_ports', 'io_pt', 'notification', 'pd', 'pt', 'tcb', + 'ut', 'vcpu'), suffix=r'\b'), Keyword.Type), + + # Properties + (words(('asid', 'addr', 'badge', 'cached', 'dom', 'domainID', 'elf', + 'fault_ep', 'G', 'guard', 'guard_size', 'init', 'ip', + 'prio', 'sp', 'R', 'RG', 'RX', 'RW', 'RWG', 'RWX', 'W', + 'WG', 'WX', 'level', 'masked', 'master_reply', 'paddr', + 'ports', 'reply', 'uncached'), suffix=r'\b'), + Keyword.Reserved), + + # Literals + (r'0[xX][\da-fA-F]+', Number.Hex), + (r'\d+(\.\d+)?(k|M)?', Number), + (words(('bits',), suffix=r'\b'), Number), + (words(('cspace', 'vspace', 'reply_slot', 'caller_slot', + 'ipc_buffer_slot'), suffix=r'\b'), Number), + + # Identifiers + (r'[a-zA-Z_][-@\.\w]*', Name), + ], + } + + +class RedcodeLexer(RegexLexer): + """ + A simple Redcode lexer based on ICWS'94. + Contributed by Adam Blinkinsop . + """ + name = 'Redcode' + aliases = ['redcode'] + filenames = ['*.cw'] + url = 'https://en.wikipedia.org/wiki/Core_War' + version_added = '0.8' + + opcodes = ('DAT', 'MOV', 'ADD', 'SUB', 'MUL', 'DIV', 'MOD', + 'JMP', 'JMZ', 'JMN', 'DJN', 'CMP', 'SLT', 'SPL', + 'ORG', 'EQU', 'END') + modifiers = ('A', 'B', 'AB', 'BA', 'F', 'X', 'I') + + tokens = { + 'root': [ + # Whitespace: + (r'\s+', Whitespace), + (r';.*$', Comment.Single), + # Lexemes: + # Identifiers + (r'\b({})\b'.format('|'.join(opcodes)), Name.Function), + (r'\b({})\b'.format('|'.join(modifiers)), Name.Decorator), + (r'[A-Za-z_]\w+', Name), + # Operators + (r'[-+*/%]', Operator), + (r'[#$@<>]', Operator), # mode + (r'[.,]', Punctuation), # mode + # Numbers + (r'[-+]?\d+', Number.Integer), + ], + } + + +class AheuiLexer(RegexLexer): + """ + Aheui is esoteric language based on Korean alphabets. + """ + + name = 'Aheui' + url = 'http://aheui.github.io/' + aliases = ['aheui'] + filenames = ['*.aheui'] + version_added = '' + + tokens = { + 'root': [ + ('[' + '나-낳냐-냫너-넣녀-녛노-놓뇨-눟뉴-닇' + '다-닿댜-댷더-덯뎌-뎧도-돟됴-둫듀-딓' + '따-땋땨-떃떠-떻뗘-뗳또-똫뚀-뚷뜌-띟' + '라-랗랴-럏러-렇려-렿로-롷료-뤃류-릫' + '마-맣먀-먛머-멓며-몋모-뫃묘-뭏뮤-믷' + '바-밯뱌-뱧버-벟벼-볗보-봏뵤-붛뷰-빃' + '빠-빻뺘-뺳뻐-뻫뼈-뼣뽀-뽛뾰-뿧쀼-삏' + '사-샇샤-샿서-섷셔-셯소-솧쇼-숳슈-싛' + '싸-쌓쌰-썋써-쎃쎠-쎻쏘-쏳쑈-쑿쓔-씧' + '자-잫쟈-쟣저-젛져-졓조-좋죠-줗쥬-즿' + '차-챃챠-챻처-첳쳐-쳫초-촣쵸-춯츄-칗' + '카-캏캬-컇커-컿켜-켷코-콯쿄-쿻큐-킣' + '타-탛탸-턓터-텋텨-톃토-톻툐-퉇튜-틯' + '파-팧퍄-퍟퍼-펗펴-폏포-퐇표-풓퓨-픻' + '하-핳햐-햫허-헣혀-혛호-홓효-훟휴-힇' + ']', Operator), + ('.', Comment), + ], + } diff --git a/pygments/lexers/ezhil.py b/pygments/lexers/ezhil.py new file mode 100644 index 0000000000000000000000000000000000000000..aeb71ab0ce793656d0e56d72d72af50b0265d4f5 --- /dev/null +++ b/pygments/lexers/ezhil.py @@ -0,0 +1,76 @@ +""" + pygments.lexers.ezhil + ~~~~~~~~~~~~~~~~~~~~~ + + Pygments lexers for Ezhil language. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +import re + +from pygments.lexer import RegexLexer, include, words +from pygments.token import Keyword, Comment, Name, String, Number, \ + Punctuation, Operator, Whitespace + +__all__ = ['EzhilLexer'] + + +class EzhilLexer(RegexLexer): + """ + Lexer for Ezhil, a Tamil script-based programming language. + """ + name = 'Ezhil' + url = 'http://ezhillang.org' + aliases = ['ezhil'] + filenames = ['*.n'] + mimetypes = ['text/x-ezhil'] + version_added = '2.1' + # Refer to tamil.utf8.tamil_letters from open-tamil for a stricter version of this. + # This much simpler version is close enough, and includes combining marks. + _TALETTERS = '[a-zA-Z_]|[\u0b80-\u0bff]' + tokens = { + 'root': [ + include('keywords'), + (r'#.*$', Comment.Single), + (r'[@+/*,^\-%]|[!<>=]=?|&&?|\|\|?', Operator), + ('இல்', Operator.Word), + (words(('assert', 'max', 'min', + 'நீளம்', 'சரம்_இடமாற்று', 'சரம்_கண்டுபிடி', + 'பட்டியல்', 'பின்இணை', 'வரிசைப்படுத்து', + 'எடு', 'தலைகீழ்', 'நீட்டிக்க', 'நுழைக்க', 'வை', + 'கோப்பை_திற', 'கோப்பை_எழுது', 'கோப்பை_மூடு', + 'pi', 'sin', 'cos', 'tan', 'sqrt', 'hypot', 'pow', + 'exp', 'log', 'log10', 'exit', + ), suffix=r'\b'), Name.Builtin), + (r'(True|False)\b', Keyword.Constant), + (r'[^\S\n]+', Whitespace), + include('identifier'), + include('literal'), + (r'[(){}\[\]:;.]', Punctuation), + ], + 'keywords': [ + ('பதிப்பி|தேர்ந்தெடு|தேர்வு|ஏதேனில்|ஆனால்|இல்லைஆனால்|இல்லை|ஆக|ஒவ்வொன்றாக|இல்|வரை|செய்|முடியேனில்|பின்கொடு|முடி|நிரல்பாகம்|தொடர்|நிறுத்து|நிரல்பாகம்', Keyword), + ], + 'identifier': [ + ('(?:'+_TALETTERS+')(?:[0-9]|'+_TALETTERS+')*', Name), + ], + 'literal': [ + (r'".*?"', String), + (r'\d+((\.\d*)?[eE][+-]?\d+|\.\d*)', Number.Float), + (r'\d+', Number.Integer), + ] + } + + def analyse_text(text): + """This language uses Tamil-script. We'll assume that if there's a + decent amount of Tamil-characters, it's this language. This assumption + is obviously horribly off if someone uses string literals in tamil + in another language.""" + if len(re.findall(r'[\u0b80-\u0bff]', text)) > 10: + return 0.25 + + def __init__(self, **options): + super().__init__(**options) + self.encoding = options.get('encoding', 'utf-8') diff --git a/pygments/lexers/factor.py b/pygments/lexers/factor.py new file mode 100644 index 0000000000000000000000000000000000000000..2ec3f9be25174c23d48b8734963a2c97461505d6 --- /dev/null +++ b/pygments/lexers/factor.py @@ -0,0 +1,363 @@ +""" + pygments.lexers.factor + ~~~~~~~~~~~~~~~~~~~~~~ + + Lexers for the Factor language. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +from pygments.lexer import RegexLexer, bygroups, default, words +from pygments.token import Text, Comment, Keyword, Name, String, Number, \ + Whitespace, Punctuation + +__all__ = ['FactorLexer'] + + +class FactorLexer(RegexLexer): + """ + Lexer for the Factor language. + """ + name = 'Factor' + url = 'http://factorcode.org' + aliases = ['factor'] + filenames = ['*.factor'] + mimetypes = ['text/x-factor'] + version_added = '1.4' + + builtin_kernel = words(( + '-rot', '2bi', '2bi@', '2bi*', '2curry', '2dip', '2drop', '2dup', '2keep', '2nip', + '2over', '2tri', '2tri@', '2tri*', '3bi', '3curry', '3dip', '3drop', '3dup', '3keep', + '3tri', '4dip', '4drop', '4dup', '4keep', '', '=', '>boolean', 'clone', + '?', '?execute', '?if', 'and', 'assert', 'assert=', 'assert?', 'bi', 'bi-curry', + 'bi-curry@', 'bi-curry*', 'bi@', 'bi*', 'boa', 'boolean', 'boolean?', 'both?', + 'build', 'call', 'callstack', 'callstack>array', 'callstack?', 'clear', '(clone)', + 'compose', 'compose?', 'curry', 'curry?', 'datastack', 'die', 'dip', 'do', 'drop', + 'dup', 'dupd', 'either?', 'eq?', 'equal?', 'execute', 'hashcode', 'hashcode*', + 'identity-hashcode', 'identity-tuple', 'identity-tuple?', 'if', 'if*', + 'keep', 'loop', 'most', 'new', 'nip', 'not', 'null', 'object', 'or', 'over', + 'pick', 'prepose', 'retainstack', 'rot', 'same?', 'swap', 'swapd', 'throw', + 'tri', 'tri-curry', 'tri-curry@', 'tri-curry*', 'tri@', 'tri*', 'tuple', + 'tuple?', 'unless', 'unless*', 'until', 'when', 'when*', 'while', 'with', + 'wrapper', 'wrapper?', 'xor'), suffix=r'(\s+)') + + builtin_assocs = words(( + '2cache', '', '>alist', '?at', '?of', 'assoc', 'assoc-all?', + 'assoc-any?', 'assoc-clone-like', 'assoc-combine', 'assoc-diff', + 'assoc-diff!', 'assoc-differ', 'assoc-each', 'assoc-empty?', + 'assoc-filter', 'assoc-filter!', 'assoc-filter-as', 'assoc-find', + 'assoc-hashcode', 'assoc-intersect', 'assoc-like', 'assoc-map', + 'assoc-map-as', 'assoc-partition', 'assoc-refine', 'assoc-size', + 'assoc-stack', 'assoc-subset?', 'assoc-union', 'assoc-union!', + 'assoc=', 'assoc>map', 'assoc?', 'at', 'at+', 'at*', 'cache', 'change-at', + 'clear-assoc', 'delete-at', 'delete-at*', 'enum', 'enum?', 'extract-keys', + 'inc-at', 'key?', 'keys', 'map>assoc', 'maybe-set-at', 'new-assoc', 'of', + 'push-at', 'rename-at', 'set-at', 'sift-keys', 'sift-values', 'substitute', + 'unzip', 'value-at', 'value-at*', 'value?', 'values', 'zip'), suffix=r'(\s+)') + + builtin_combinators = words(( + '2cleave', '2cleave>quot', '3cleave', '3cleave>quot', '4cleave', + '4cleave>quot', 'alist>quot', 'call-effect', 'case', 'case-find', + 'case>quot', 'cleave', 'cleave>quot', 'cond', 'cond>quot', 'deep-spread>quot', + 'execute-effect', 'linear-case-quot', 'no-case', 'no-case?', 'no-cond', + 'no-cond?', 'recursive-hashcode', 'shallow-spread>quot', 'spread', + 'to-fixed-point', 'wrong-values', 'wrong-values?'), suffix=r'(\s+)') + + builtin_math = words(( + '-', '/', '/f', '/i', '/mod', '2/', '2^', '<', '<=', '', '>', + '>=', '>bignum', '>fixnum', '>float', '>integer', '(all-integers?)', + '(each-integer)', '(find-integer)', '*', '+', '?1+', + 'abs', 'align', 'all-integers?', 'bignum', 'bignum?', 'bit?', 'bitand', + 'bitnot', 'bitor', 'bits>double', 'bits>float', 'bitxor', 'complex', + 'complex?', 'denominator', 'double>bits', 'each-integer', 'even?', + 'find-integer', 'find-last-integer', 'fixnum', 'fixnum?', 'float', + 'float>bits', 'float?', 'fp-bitwise=', 'fp-infinity?', 'fp-nan-payload', + 'fp-nan?', 'fp-qnan?', 'fp-sign', 'fp-snan?', 'fp-special?', + 'if-zero', 'imaginary-part', 'integer', 'integer>fixnum', + 'integer>fixnum-strict', 'integer?', 'log2', 'log2-expects-positive', + 'log2-expects-positive?', 'mod', 'neg', 'neg?', 'next-float', + 'next-power-of-2', 'number', 'number=', 'number?', 'numerator', 'odd?', + 'out-of-fixnum-range', 'out-of-fixnum-range?', 'power-of-2?', + 'prev-float', 'ratio', 'ratio?', 'rational', 'rational?', 'real', + 'real-part', 'real?', 'recip', 'rem', 'sgn', 'shift', 'sq', 'times', + 'u<', 'u<=', 'u>', 'u>=', 'unless-zero', 'unordered?', 'when-zero', + 'zero?'), suffix=r'(\s+)') + + builtin_sequences = words(( + '1sequence', '2all?', '2each', '2map', '2map-as', '2map-reduce', '2reduce', + '2selector', '2sequence', '3append', '3append-as', '3each', '3map', '3map-as', + '3sequence', '4sequence', '', '', '', '?first', + '?last', '?nth', '?second', '?set-nth', 'accumulate', 'accumulate!', + 'accumulate-as', 'all?', 'any?', 'append', 'append!', 'append-as', + 'assert-sequence', 'assert-sequence=', 'assert-sequence?', + 'binary-reduce', 'bounds-check', 'bounds-check?', 'bounds-error', + 'bounds-error?', 'but-last', 'but-last-slice', 'cartesian-each', + 'cartesian-map', 'cartesian-product', 'change-nth', 'check-slice', + 'check-slice-error', 'clone-like', 'collapse-slice', 'collector', + 'collector-for', 'concat', 'concat-as', 'copy', 'count', 'cut', 'cut-slice', + 'cut*', 'delete-all', 'delete-slice', 'drop-prefix', 'each', 'each-from', + 'each-index', 'empty?', 'exchange', 'filter', 'filter!', 'filter-as', 'find', + 'find-from', 'find-index', 'find-index-from', 'find-last', 'find-last-from', + 'first', 'first2', 'first3', 'first4', 'flip', 'follow', 'fourth', 'glue', 'halves', + 'harvest', 'head', 'head-slice', 'head-slice*', 'head*', 'head?', + 'if-empty', 'immutable', 'immutable-sequence', 'immutable-sequence?', + 'immutable?', 'index', 'index-from', 'indices', 'infimum', 'infimum-by', + 'insert-nth', 'interleave', 'iota', 'iota-tuple', 'iota-tuple?', 'join', + 'join-as', 'last', 'last-index', 'last-index-from', 'length', 'lengthen', + 'like', 'longer', 'longer?', 'longest', 'map', 'map!', 'map-as', 'map-find', + 'map-find-last', 'map-index', 'map-integers', 'map-reduce', 'map-sum', + 'max-length', 'member-eq?', 'member?', 'midpoint@', 'min-length', + 'mismatch', 'move', 'new-like', 'new-resizable', 'new-sequence', + 'non-negative-integer-expected', 'non-negative-integer-expected?', + 'nth', 'nths', 'pad-head', 'pad-tail', 'padding', 'partition', 'pop', 'pop*', + 'prefix', 'prepend', 'prepend-as', 'produce', 'produce-as', 'product', 'push', + 'push-all', 'push-either', 'push-if', 'reduce', 'reduce-index', 'remove', + 'remove!', 'remove-eq', 'remove-eq!', 'remove-nth', 'remove-nth!', 'repetition', + 'repetition?', 'replace-slice', 'replicate', 'replicate-as', 'rest', + 'rest-slice', 'reverse', 'reverse!', 'reversed', 'reversed?', 'second', + 'selector', 'selector-for', 'sequence', 'sequence-hashcode', 'sequence=', + 'sequence?', 'set-first', 'set-fourth', 'set-last', 'set-length', 'set-nth', + 'set-second', 'set-third', 'short', 'shorten', 'shorter', 'shorter?', + 'shortest', 'sift', 'slice', 'slice-error', 'slice-error?', 'slice?', + 'snip', 'snip-slice', 'start', 'start*', 'subseq', 'subseq?', 'suffix', + 'suffix!', 'sum', 'sum-lengths', 'supremum', 'supremum-by', 'surround', 'tail', + 'tail-slice', 'tail-slice*', 'tail*', 'tail?', 'third', 'trim', + 'trim-head', 'trim-head-slice', 'trim-slice', 'trim-tail', 'trim-tail-slice', + 'unclip', 'unclip-last', 'unclip-last-slice', 'unclip-slice', 'unless-empty', + 'virtual-exemplar', 'virtual-sequence', 'virtual-sequence?', 'virtual@', + 'when-empty'), suffix=r'(\s+)') + + builtin_namespaces = words(( + '+@', 'change', 'change-global', 'counter', 'dec', 'get', 'get-global', + 'global', 'inc', 'init-namespaces', 'initialize', 'is-global', 'make-assoc', + 'namespace', 'namestack', 'off', 'on', 'set', 'set-global', 'set-namestack', + 'toggle', 'with-global', 'with-scope', 'with-variable', 'with-variables'), + suffix=r'(\s+)') + + builtin_arrays = words(( + '1array', '2array', '3array', '4array', '', '>array', 'array', + 'array?', 'pair', 'pair?', 'resize-array'), suffix=r'(\s+)') + + builtin_io = words(( + '(each-stream-block-slice)', '(each-stream-block)', + '(stream-contents-by-block)', '(stream-contents-by-element)', + '(stream-contents-by-length-or-block)', + '(stream-contents-by-length)', '+byte+', '+character+', + 'bad-seek-type', 'bad-seek-type?', 'bl', 'contents', 'each-block', + 'each-block-size', 'each-block-slice', 'each-line', 'each-morsel', + 'each-stream-block', 'each-stream-block-slice', 'each-stream-line', + 'error-stream', 'flush', 'input-stream', 'input-stream?', + 'invalid-read-buffer', 'invalid-read-buffer?', 'lines', 'nl', + 'output-stream', 'output-stream?', 'print', 'read', 'read-into', + 'read-partial', 'read-partial-into', 'read-until', 'read1', 'readln', + 'seek-absolute', 'seek-absolute?', 'seek-end', 'seek-end?', + 'seek-input', 'seek-output', 'seek-relative', 'seek-relative?', + 'stream-bl', 'stream-contents', 'stream-contents*', 'stream-copy', + 'stream-copy*', 'stream-element-type', 'stream-flush', + 'stream-length', 'stream-lines', 'stream-nl', 'stream-print', + 'stream-read', 'stream-read-into', 'stream-read-partial', + 'stream-read-partial-into', 'stream-read-partial-unsafe', + 'stream-read-unsafe', 'stream-read-until', 'stream-read1', + 'stream-readln', 'stream-seek', 'stream-seekable?', 'stream-tell', + 'stream-write', 'stream-write1', 'tell-input', 'tell-output', + 'with-error-stream', 'with-error-stream*', 'with-error>output', + 'with-input-output+error-streams', + 'with-input-output+error-streams*', 'with-input-stream', + 'with-input-stream*', 'with-output-stream', 'with-output-stream*', + 'with-output>error', 'with-output+error-stream', + 'with-output+error-stream*', 'with-streams', 'with-streams*', + 'write', 'write1'), suffix=r'(\s+)') + + builtin_strings = words(( + '1string', '', '>string', 'resize-string', 'string', + 'string?'), suffix=r'(\s+)') + + builtin_vectors = words(( + '1vector', '', '>vector', '?push', 'vector', 'vector?'), + suffix=r'(\s+)') + + builtin_continuations = words(( + '', '', '', 'attempt-all', + 'attempt-all-error', 'attempt-all-error?', 'callback-error-hook', + 'callcc0', 'callcc1', 'cleanup', 'compute-restarts', 'condition', + 'condition?', 'continuation', 'continuation?', 'continue', + 'continue-restart', 'continue-with', 'current-continuation', + 'error', 'error-continuation', 'error-in-thread', 'error-thread', + 'ifcc', 'ignore-errors', 'in-callback?', 'original-error', 'recover', + 'restart', 'restart?', 'restarts', 'rethrow', 'rethrow-restarts', + 'return', 'return-continuation', 'thread-error-hook', 'throw-continue', + 'throw-restarts', 'with-datastack', 'with-return'), suffix=r'(\s+)') + + tokens = { + 'root': [ + # factor allows a file to start with a shebang + (r'#!.*$', Comment.Preproc), + default('base'), + ], + 'base': [ + (r'\s+', Whitespace), + + # defining words + (r'((?:MACRO|MEMO|TYPED)?:[:]?)(\s+)(\S+)', + bygroups(Keyword, Whitespace, Name.Function)), + (r'(M:[:]?)(\s+)(\S+)(\s+)(\S+)', + bygroups(Keyword, Whitespace, Name.Class, Whitespace, + Name.Function)), + (r'(C:)(\s+)(\S+)(\s+)(\S+)', + bygroups(Keyword, Whitespace, Name.Function, Whitespace, + Name.Class)), + (r'(GENERIC:)(\s+)(\S+)', + bygroups(Keyword, Whitespace, Name.Function)), + (r'(HOOK:|GENERIC#)(\s+)(\S+)(\s+)(\S+)', + bygroups(Keyword, Whitespace, Name.Function, Whitespace, + Name.Function)), + (r'(\()(\s)', bygroups(Name.Function, Whitespace), 'stackeffect'), + (r'(;)(\s)', bygroups(Keyword, Whitespace)), + + # imports and namespaces + (r'(USING:)(\s+)', + bygroups(Keyword.Namespace, Whitespace), 'vocabs'), + (r'(USE:|UNUSE:|IN:|QUALIFIED:)(\s+)(\S+)', + bygroups(Keyword.Namespace, Whitespace, Name.Namespace)), + (r'(QUALIFIED-WITH:)(\s+)(\S+)(\s+)(\S+)', + bygroups(Keyword.Namespace, Whitespace, Name.Namespace, + Whitespace, Name.Namespace)), + (r'(FROM:|EXCLUDE:)(\s+)(\S+)(\s+=>\s)', + bygroups(Keyword.Namespace, Whitespace, Name.Namespace, + Whitespace), 'words'), + (r'(RENAME:)(\s+)(\S+)(\s+)(\S+)(\s+)(=>)(\s+)(\S+)', + bygroups(Keyword.Namespace, Whitespace, Name.Function, Whitespace, + Name.Namespace, Whitespace, Punctuation, Whitespace, + Name.Function)), + (r'(ALIAS:|TYPEDEF:)(\s+)(\S+)(\s+)(\S+)', + bygroups(Keyword.Namespace, Whitespace, Name.Function, Whitespace, + Name.Function)), + (r'(DEFER:|FORGET:|POSTPONE:)(\s+)(\S+)', + bygroups(Keyword.Namespace, Whitespace, Name.Function)), + + # tuples and classes + (r'(TUPLE:|ERROR:)(\s+)(\S+)(\s+)(<)(\s+)(\S+)', + bygroups(Keyword, Whitespace, Name.Class, Whitespace, Punctuation, + Whitespace, Name.Class), 'slots'), + (r'(TUPLE:|ERROR:|BUILTIN:)(\s+)(\S+)', + bygroups(Keyword, Whitespace, Name.Class), 'slots'), + (r'(MIXIN:|UNION:|INTERSECTION:)(\s+)(\S+)', + bygroups(Keyword, Whitespace, Name.Class)), + (r'(PREDICATE:)(\s+)(\S+)(\s+)(<)(\s+)(\S+)', + bygroups(Keyword, Whitespace, Name.Class, Whitespace, + Punctuation, Whitespace, Name.Class)), + (r'(C:)(\s+)(\S+)(\s+)(\S+)', + bygroups(Keyword, Whitespace, Name.Function, Whitespace, Name.Class)), + (r'(INSTANCE:)(\s+)(\S+)(\s+)(\S+)', + bygroups(Keyword, Whitespace, Name.Class, Whitespace, Name.Class)), + (r'(SLOT:)(\s+)(\S+)', bygroups(Keyword, Whitespace, Name.Function)), + (r'(SINGLETON:)(\s+)(\S+)', bygroups(Keyword, Whitespace, Name.Class)), + (r'SINGLETONS:', Keyword, 'classes'), + + # other syntax + (r'(CONSTANT:|SYMBOL:|MAIN:|HELP:)(\s+)(\S+)', + bygroups(Keyword, Whitespace, Name.Function)), + (r'(SYMBOLS:)(\s+)', bygroups(Keyword, Whitespace), 'words'), + (r'(SYNTAX:)(\s+)', bygroups(Keyword, Whitespace)), + (r'(ALIEN:)(\s+)', bygroups(Keyword, Whitespace)), + (r'(STRUCT:)(\s+)(\S+)', bygroups(Keyword, Whitespace, Name.Class)), + (r'(FUNCTION:)(\s+)' + r'(\S+)(\s+)(\S+)(\s+)' + r'(\()(\s+)([^)]+)(\))(\s)', + bygroups(Keyword.Namespace, Whitespace, + Text, Whitespace, Name.Function, Whitespace, + Punctuation, Whitespace, Text, Punctuation, Whitespace)), + (r'(FUNCTION-ALIAS:)(\s+)' + r'(\S+)(\s+)(\S+)(\s+)' + r'(\S+)(\s+)' + r'(\()(\s+)([^)]+)(\))(\s)', + bygroups(Keyword.Namespace, Whitespace, + Text, Whitespace, Name.Function, Whitespace, + Name.Function, Whitespace, + Punctuation, Whitespace, Text, Punctuation, Whitespace)), + + # vocab.private + (r'()(\s)', bygroups(Keyword.Namespace, Whitespace)), + + # strings + (r'"""\s(?:.|\n)*?\s"""', String), + (r'"(?:\\\\|\\"|[^"])*"', String), + (r'(\S+")(\s+)((?:\\\\|\\"|[^"])*")', + bygroups(String, Whitespace, String)), + (r'(CHAR:)(\s+)(\\[\\abfnrstv]|[^\\]\S*)(\s)', + bygroups(String.Char, Whitespace, String.Char, Whitespace)), + + # comments + (r'!\s+.*$', Comment), + (r'#!\s+.*$', Comment), + (r'/\*\s+(?:.|\n)*?\s\*/', Comment), + + # boolean constants + (r'[tf]\b', Name.Constant), + + # symbols and literals + (r'[\\$]\s+\S+', Name.Constant), + (r'M\\\s+\S+\s+\S+', Name.Constant), + + # numbers + (r'[+-]?(?:[\d,]*\d)?\.(?:\d([\d,]*\d)?)?(?:[eE][+-]?\d+)?\s', Number), + (r'[+-]?\d(?:[\d,]*\d)?(?:[eE][+-]?\d+)?\s', Number), + (r'0x[a-fA-F\d](?:[a-fA-F\d,]*[a-fA-F\d])?(?:p\d([\d,]*\d)?)?\s', Number), + (r'NAN:\s+[a-fA-F\d](?:[a-fA-F\d,]*[a-fA-F\d])?(?:p\d([\d,]*\d)?)?\s', Number), + (r'0b[01]+\s', Number.Bin), + (r'0o[0-7]+\s', Number.Oct), + (r'(?:\d([\d,]*\d)?)?\+\d(?:[\d,]*\d)?/\d(?:[\d,]*\d)?\s', Number), + (r'(?:\-\d([\d,]*\d)?)?\-\d(?:[\d,]*\d)?/\d(?:[\d,]*\d)?\s', Number), + + # keywords + (r'(?:deprecated|final|foldable|flushable|inline|recursive)\s', + Keyword), + + # builtins + (builtin_kernel, bygroups(Name.Builtin, Whitespace)), + (builtin_assocs, bygroups(Name.Builtin, Whitespace)), + (builtin_combinators, bygroups(Name.Builtin, Whitespace)), + (builtin_math, bygroups(Name.Builtin, Whitespace)), + (builtin_sequences, bygroups(Name.Builtin, Whitespace)), + (builtin_namespaces, bygroups(Name.Builtin, Whitespace)), + (builtin_arrays, bygroups(Name.Builtin, Whitespace)), + (builtin_io, bygroups(Name.Builtin, Whitespace)), + (builtin_strings, bygroups(Name.Builtin, Whitespace)), + (builtin_vectors, bygroups(Name.Builtin, Whitespace)), + (builtin_continuations, bygroups(Name.Builtin, Whitespace)), + + # everything else is text + (r'\S+', Text), + ], + 'stackeffect': [ + (r'\s+', Whitespace), + (r'(\()(\s+)', bygroups(Name.Function, Whitespace), 'stackeffect'), + (r'(\))(\s+)', bygroups(Name.Function, Whitespace), '#pop'), + (r'(--)(\s+)', bygroups(Name.Function, Whitespace)), + (r'\S+', Name.Variable), + ], + 'slots': [ + (r'\s+', Whitespace), + (r'(;)(\s+)', bygroups(Keyword, Whitespace), '#pop'), + (r'(\{)(\s+)(\S+)(\s+)([^}]+)(\s+)(\})(\s+)', + bygroups(Text, Whitespace, Name.Variable, Whitespace, + Text, Whitespace, Text, Whitespace)), + (r'\S+', Name.Variable), + ], + 'vocabs': [ + (r'\s+', Whitespace), + (r'(;)(\s+)', bygroups(Keyword, Whitespace), '#pop'), + (r'\S+', Name.Namespace), + ], + 'classes': [ + (r'\s+', Whitespace), + (r'(;)(\s+)', bygroups(Keyword, Whitespace), '#pop'), + (r'\S+', Name.Class), + ], + 'words': [ + (r'\s+', Whitespace), + (r'(;)(\s+)', bygroups(Keyword, Whitespace), '#pop'), + (r'\S+', Name.Function), + ], + } diff --git a/pygments/lexers/fantom.py b/pygments/lexers/fantom.py new file mode 100644 index 0000000000000000000000000000000000000000..7228554ba9b42b0d8667e8eefe16b622253082d0 --- /dev/null +++ b/pygments/lexers/fantom.py @@ -0,0 +1,251 @@ +""" + pygments.lexers.fantom + ~~~~~~~~~~~~~~~~~~~~~~ + + Lexer for the Fantom language. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +from string import Template + +from pygments.lexer import RegexLexer, include, bygroups, using, \ + this, default, words +from pygments.token import Text, Comment, Operator, Keyword, Name, String, \ + Number, Punctuation, Literal, Whitespace + +__all__ = ['FantomLexer'] + + +class FantomLexer(RegexLexer): + """ + For Fantom source code. + """ + name = 'Fantom' + aliases = ['fan'] + filenames = ['*.fan'] + mimetypes = ['application/x-fantom'] + url = 'https://www.fantom.org' + version_added = '1.5' + + # often used regexes + def s(str): + return Template(str).substitute( + dict( + pod=r'[\"\w\.]+', + eos=r'\n|;', + id=r'[a-zA-Z_]\w*', + # all chars which can be part of type definition. Starts with + # either letter, or [ (maps), or | (funcs) + type=r'(?:\[|[a-zA-Z_]|\|)[:\w\[\]|\->?]*?', + ) + ) + + tokens = { + 'comments': [ + (r'(?s)/\*.*?\*/', Comment.Multiline), # Multiline + (r'//.*?$', Comment.Single), # Single line + # TODO: highlight references in fandocs + (r'\*\*.*?$', Comment.Special), # Fandoc + (r'#.*$', Comment.Single) # Shell-style + ], + 'literals': [ + (r'\b-?[\d_]+(ns|ms|sec|min|hr|day)', Number), # Duration + (r'\b-?[\d_]*\.[\d_]+(ns|ms|sec|min|hr|day)', Number), # Duration with dot + (r'\b-?(\d+)?\.\d+(f|F|d|D)?', Number.Float), # Float/Decimal + (r'\b-?0x[0-9a-fA-F_]+', Number.Hex), # Hex + (r'\b-?[\d_]+', Number.Integer), # Int + (r"'\\.'|'[^\\]'|'\\u[0-9a-f]{4}'", String.Char), # Char + (r'"', Punctuation, 'insideStr'), # Opening quote + (r'`', Punctuation, 'insideUri'), # Opening accent + (r'\b(true|false|null)\b', Keyword.Constant), # Bool & null + (r'(?:(\w+)(::))?(\w+)(<\|)(.*?)(\|>)', # DSL + bygroups(Name.Namespace, Punctuation, Name.Class, + Punctuation, String, Punctuation)), + (r'(?:(\w+)(::))?(\w+)?(#)(\w+)?', # Type/slot literal + bygroups(Name.Namespace, Punctuation, Name.Class, + Punctuation, Name.Function)), + (r'\[,\]', Literal), # Empty list + (s(r'($type)(\[,\])'), # Typed empty list + bygroups(using(this, state='inType'), Literal)), + (r'\[:\]', Literal), # Empty Map + (s(r'($type)(\[:\])'), + bygroups(using(this, state='inType'), Literal)), + ], + 'insideStr': [ + (r'\\\\', String.Escape), # Escaped backslash + (r'\\"', String.Escape), # Escaped " + (r'\\`', String.Escape), # Escaped ` + (r'\$\w+', String.Interpol), # Subst var + (r'\$\{.*?\}', String.Interpol), # Subst expr + (r'"', Punctuation, '#pop'), # Closing quot + (r'.', String) # String content + ], + 'insideUri': [ # TODO: remove copy/paste str/uri + (r'\\\\', String.Escape), # Escaped backslash + (r'\\"', String.Escape), # Escaped " + (r'\\`', String.Escape), # Escaped ` + (r'\$\w+', String.Interpol), # Subst var + (r'\$\{.*?\}', String.Interpol), # Subst expr + (r'`', Punctuation, '#pop'), # Closing tick + (r'.', String.Backtick) # URI content + ], + 'protectionKeywords': [ + (r'\b(public|protected|private|internal)\b', Keyword), + ], + 'typeKeywords': [ + (r'\b(abstract|final|const|native|facet|enum)\b', Keyword), + ], + 'methodKeywords': [ + (r'\b(abstract|native|once|override|static|virtual|final)\b', + Keyword), + ], + 'fieldKeywords': [ + (r'\b(abstract|const|final|native|override|static|virtual|' + r'readonly)\b', Keyword) + ], + 'otherKeywords': [ + (words(( + 'try', 'catch', 'throw', 'finally', 'for', 'if', 'else', 'while', + 'as', 'is', 'isnot', 'switch', 'case', 'default', 'continue', + 'break', 'do', 'return', 'get', 'set'), prefix=r'\b', suffix=r'\b'), + Keyword), + (r'\b(it|this|super)\b', Name.Builtin.Pseudo), + ], + 'operators': [ + (r'\+\+|\-\-|\+|\-|\*|/|\|\||&&|<=>|<=|<|>=|>|=|!|\[|\]', Operator) + ], + 'inType': [ + (r'[\[\]|\->:?]', Punctuation), + (s(r'$id'), Name.Class), + default('#pop'), + + ], + 'root': [ + include('comments'), + include('protectionKeywords'), + include('typeKeywords'), + include('methodKeywords'), + include('fieldKeywords'), + include('literals'), + include('otherKeywords'), + include('operators'), + (r'using\b', Keyword.Namespace, 'using'), # Using stmt + (r'@\w+', Name.Decorator, 'facet'), # Symbol + (r'(class|mixin)(\s+)(\w+)', bygroups(Keyword, Whitespace, Name.Class), + 'inheritance'), # Inheritance list + + # Type var := val + (s(r'($type)([ \t]+)($id)(\s*)(:=)'), + bygroups(using(this, state='inType'), Whitespace, + Name.Variable, Whitespace, Operator)), + + # var := val + (s(r'($id)(\s*)(:=)'), + bygroups(Name.Variable, Whitespace, Operator)), + + # .someId( or ->someId( ### + (s(r'(\.|(?:\->))($id)(\s*)(\()'), + bygroups(Operator, Name.Function, Whitespace, Punctuation), + 'insideParen'), + + # .someId or ->someId + (s(r'(\.|(?:\->))($id)'), + bygroups(Operator, Name.Function)), + + # new makeXXX ( + (r'(new)(\s+)(make\w*)(\s*)(\()', + bygroups(Keyword, Whitespace, Name.Function, Whitespace, Punctuation), + 'insideMethodDeclArgs'), + + # Type name ( + (s(r'($type)([ \t]+)' # Return type and whitespace + r'($id)(\s*)(\()'), # method name + open brace + bygroups(using(this, state='inType'), Whitespace, + Name.Function, Whitespace, Punctuation), + 'insideMethodDeclArgs'), + + # ArgType argName, + (s(r'($type)(\s+)($id)(\s*)(,)'), + bygroups(using(this, state='inType'), Whitespace, Name.Variable, + Whitespace, Punctuation)), + + # ArgType argName) + # Covered in 'insideParen' state + + # ArgType argName -> ArgType| + (s(r'($type)(\s+)($id)(\s*)(\->)(\s*)($type)(\|)'), + bygroups(using(this, state='inType'), Whitespace, Name.Variable, + Whitespace, Punctuation, Whitespace, using(this, state='inType'), + Punctuation)), + + # ArgType argName| + (s(r'($type)(\s+)($id)(\s*)(\|)'), + bygroups(using(this, state='inType'), Whitespace, Name.Variable, + Whitespace, Punctuation)), + + # Type var + (s(r'($type)([ \t]+)($id)'), + bygroups(using(this, state='inType'), Whitespace, + Name.Variable)), + + (r'\(', Punctuation, 'insideParen'), + (r'\{', Punctuation, 'insideBrace'), + (r'\s+', Whitespace), + (r'.', Text) + ], + 'insideParen': [ + (r'\)', Punctuation, '#pop'), + include('root'), + ], + 'insideMethodDeclArgs': [ + (r'\)', Punctuation, '#pop'), + (s(r'($type)(\s+)($id)(\s*)(\))'), + bygroups(using(this, state='inType'), Whitespace, Name.Variable, + Whitespace, Punctuation), '#pop'), + include('root'), + ], + 'insideBrace': [ + (r'\}', Punctuation, '#pop'), + include('root'), + ], + 'inheritance': [ + (r'\s+', Whitespace), # Whitespace + (r':|,', Punctuation), + (r'(?:(\w+)(::))?(\w+)', + bygroups(Name.Namespace, Punctuation, Name.Class)), + (r'\{', Punctuation, '#pop') + ], + 'using': [ + (r'[ \t]+', Whitespace), # consume whitespaces + (r'(\[)(\w+)(\])', + bygroups(Punctuation, Comment.Special, Punctuation)), # ffi + (r'(\")?([\w.]+)(\")?', + bygroups(Punctuation, Name.Namespace, Punctuation)), # podname + (r'::', Punctuation, 'usingClass'), + default('#pop') + ], + 'usingClass': [ + (r'[ \t]+', Whitespace), # consume whitespaces + (r'(as)(\s+)(\w+)', + bygroups(Keyword.Declaration, Whitespace, Name.Class), '#pop:2'), + (r'[\w$]+', Name.Class), + default('#pop:2') # jump out to root state + ], + 'facet': [ + (r'\s+', Whitespace), + (r'\{', Punctuation, 'facetFields'), + default('#pop') + ], + 'facetFields': [ + include('comments'), + include('literals'), + include('operators'), + (r'\s+', Whitespace), + (r'(\s*)(\w+)(\s*)(=)', bygroups(Whitespace, Name, Whitespace, Operator)), + (r'\}', Punctuation, '#pop'), + (r'\s+', Whitespace), + (r'.', Text) + ], + } diff --git a/pygments/lexers/felix.py b/pygments/lexers/felix.py new file mode 100644 index 0000000000000000000000000000000000000000..a42ac08057642c26978bef4d0d624dd896d4c6cd --- /dev/null +++ b/pygments/lexers/felix.py @@ -0,0 +1,275 @@ +""" + pygments.lexers.felix + ~~~~~~~~~~~~~~~~~~~~~ + + Lexer for the Felix language. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +from pygments.lexer import RegexLexer, include, bygroups, default, words, \ + combined +from pygments.token import Text, Comment, Operator, Keyword, Name, String, \ + Number, Punctuation, Whitespace + +__all__ = ['FelixLexer'] + + +class FelixLexer(RegexLexer): + """ + For Felix source code. + """ + + name = 'Felix' + url = 'http://www.felix-lang.org' + aliases = ['felix', 'flx'] + filenames = ['*.flx', '*.flxh'] + mimetypes = ['text/x-felix'] + version_added = '1.2' + + preproc = ( + 'elif', 'else', 'endif', 'if', 'ifdef', 'ifndef', + ) + + keywords = ( + '_', '_deref', 'all', 'as', + 'assert', 'attempt', 'call', 'callback', 'case', 'caseno', 'cclass', + 'code', 'compound', 'ctypes', 'do', 'done', 'downto', 'elif', 'else', + 'endattempt', 'endcase', 'endif', 'endmatch', 'enum', 'except', + 'exceptions', 'expect', 'finally', 'for', 'forall', 'forget', 'fork', + 'functor', 'goto', 'ident', 'if', 'incomplete', 'inherit', 'instance', + 'interface', 'jump', 'lambda', 'loop', 'match', 'module', 'namespace', + 'new', 'noexpand', 'nonterm', 'obj', 'of', 'open', 'parse', 'raise', + 'regexp', 'reglex', 'regmatch', 'rename', 'return', 'the', 'then', + 'to', 'type', 'typecase', 'typedef', 'typematch', 'typeof', 'upto', + 'when', 'whilst', 'with', 'yield', + ) + + keyword_directives = ( + '_gc_pointer', '_gc_type', 'body', 'comment', 'const', 'export', + 'header', 'inline', 'lval', 'macro', 'noinline', 'noreturn', + 'package', 'private', 'pod', 'property', 'public', 'publish', + 'requires', 'todo', 'virtual', 'use', + ) + + keyword_declarations = ( + 'def', 'let', 'ref', 'val', 'var', + ) + + keyword_types = ( + 'unit', 'void', 'any', 'bool', + 'byte', 'offset', + 'address', 'caddress', 'cvaddress', 'vaddress', + 'tiny', 'short', 'int', 'long', 'vlong', + 'utiny', 'ushort', 'vshort', 'uint', 'ulong', 'uvlong', + 'int8', 'int16', 'int32', 'int64', + 'uint8', 'uint16', 'uint32', 'uint64', + 'float', 'double', 'ldouble', + 'complex', 'dcomplex', 'lcomplex', + 'imaginary', 'dimaginary', 'limaginary', + 'char', 'wchar', 'uchar', + 'charp', 'charcp', 'ucharp', 'ucharcp', + 'string', 'wstring', 'ustring', + 'cont', + 'array', 'varray', 'list', + 'lvalue', 'opt', 'slice', + ) + + keyword_constants = ( + 'false', 'true', + ) + + operator_words = ( + 'and', 'not', 'in', 'is', 'isin', 'or', 'xor', + ) + + name_builtins = ( + '_svc', 'while', + ) + + name_pseudo = ( + 'root', 'self', 'this', + ) + + decimal_suffixes = '([tTsSiIlLvV]|ll|LL|([iIuU])(8|16|32|64))?' + + tokens = { + 'root': [ + include('whitespace'), + + # Keywords + (words(('axiom', 'ctor', 'fun', 'gen', 'proc', 'reduce', + 'union'), suffix=r'\b'), + Keyword, 'funcname'), + (words(('class', 'cclass', 'cstruct', 'obj', 'struct'), suffix=r'\b'), + Keyword, 'classname'), + (r'(instance|module|typeclass)\b', Keyword, 'modulename'), + + (words(keywords, suffix=r'\b'), Keyword), + (words(keyword_directives, suffix=r'\b'), Name.Decorator), + (words(keyword_declarations, suffix=r'\b'), Keyword.Declaration), + (words(keyword_types, suffix=r'\b'), Keyword.Type), + (words(keyword_constants, suffix=r'\b'), Keyword.Constant), + + # Operators + include('operators'), + + # Float Literal + # -- Hex Float + (r'0[xX]([0-9a-fA-F_]*\.[0-9a-fA-F_]+|[0-9a-fA-F_]+)' + r'[pP][+\-]?[0-9_]+[lLfFdD]?', Number.Float), + # -- DecimalFloat + (r'[0-9_]+(\.[0-9_]+[eE][+\-]?[0-9_]+|' + r'\.[0-9_]*|[eE][+\-]?[0-9_]+)[lLfFdD]?', Number.Float), + (r'\.(0|[1-9][0-9_]*)([eE][+\-]?[0-9_]+)?[lLfFdD]?', + Number.Float), + + # IntegerLiteral + # -- Binary + (rf'0[Bb][01_]+{decimal_suffixes}', Number.Bin), + # -- Octal + (rf'0[0-7_]+{decimal_suffixes}', Number.Oct), + # -- Hexadecimal + (rf'0[xX][0-9a-fA-F_]+{decimal_suffixes}', Number.Hex), + # -- Decimal + (rf'(0|[1-9][0-9_]*){decimal_suffixes}', Number.Integer), + + # Strings + ('([rR][cC]?|[cC][rR])"""', String, 'tdqs'), + ("([rR][cC]?|[cC][rR])'''", String, 'tsqs'), + ('([rR][cC]?|[cC][rR])"', String, 'dqs'), + ("([rR][cC]?|[cC][rR])'", String, 'sqs'), + ('[cCfFqQwWuU]?"""', String, combined('stringescape', 'tdqs')), + ("[cCfFqQwWuU]?'''", String, combined('stringescape', 'tsqs')), + ('[cCfFqQwWuU]?"', String, combined('stringescape', 'dqs')), + ("[cCfFqQwWuU]?'", String, combined('stringescape', 'sqs')), + + # Punctuation + (r'[\[\]{}:(),;?]', Punctuation), + + # Labels + (r'[a-zA-Z_]\w*:>', Name.Label), + + # Identifiers + (r'({})\b'.format('|'.join(name_builtins)), Name.Builtin), + (r'({})\b'.format('|'.join(name_pseudo)), Name.Builtin.Pseudo), + (r'[a-zA-Z_]\w*', Name), + ], + 'whitespace': [ + (r'\s+', Whitespace), + + include('comment'), + + # Preprocessor + (r'(#)(\s*)(if)(\s+)(0)', + bygroups(Comment.Preproc, Whitespace, Comment.Preproc, + Whitespace, Comment.Preproc), 'if0'), + (r'#', Comment.Preproc, 'macro'), + ], + 'operators': [ + (r'({})\b'.format('|'.join(operator_words)), Operator.Word), + (r'!=|==|<<|>>|\|\||&&|[-~+/*%=<>&^|.$]', Operator), + ], + 'comment': [ + (r'//(.*?)$', Comment.Single), + (r'/[*]', Comment.Multiline, 'comment2'), + ], + 'comment2': [ + (r'[^/*]', Comment.Multiline), + (r'/[*]', Comment.Multiline, '#push'), + (r'[*]/', Comment.Multiline, '#pop'), + (r'[/*]', Comment.Multiline), + ], + 'if0': [ + (r'^(\s*)(#if.*?(?]*?>)', + bygroups(Comment.Preproc, Whitespace, String), '#pop'), + (r'(import|include)(\s+)("[^"]*?")', + bygroups(Comment.Preproc, Whitespace, String), '#pop'), + (r"(import|include)(\s+)('[^']*?')", + bygroups(Comment.Preproc, Whitespace, String), '#pop'), + (r'[^/\n]+', Comment.Preproc), + # (r'/[*](.|\n)*?[*]/', Comment), + # (r'//.*?\n', Comment, '#pop'), + (r'/', Comment.Preproc), + (r'(?<=\\)\n', Comment.Preproc), + (r'\n', Whitespace, '#pop'), + ], + 'funcname': [ + include('whitespace'), + (r'[a-zA-Z_]\w*', Name.Function, '#pop'), + # anonymous functions + (r'(?=\()', Text, '#pop'), + ], + 'classname': [ + include('whitespace'), + (r'[a-zA-Z_]\w*', Name.Class, '#pop'), + # anonymous classes + (r'(?=\{)', Text, '#pop'), + ], + 'modulename': [ + include('whitespace'), + (r'\[', Punctuation, ('modulename2', 'tvarlist')), + default('modulename2'), + ], + 'modulename2': [ + include('whitespace'), + (r'([a-zA-Z_]\w*)', Name.Namespace, '#pop:2'), + ], + 'tvarlist': [ + include('whitespace'), + include('operators'), + (r'\[', Punctuation, '#push'), + (r'\]', Punctuation, '#pop'), + (r',', Punctuation), + (r'(with|where)\b', Keyword), + (r'[a-zA-Z_]\w*', Name), + ], + 'stringescape': [ + (r'\\([\\abfnrtv"\']|\n|N\{.*?\}|u[a-fA-F0-9]{4}|' + r'U[a-fA-F0-9]{8}|x[a-fA-F0-9]{2}|[0-7]{1,3})', String.Escape) + ], + 'strings': [ + (r'%(\([a-zA-Z0-9]+\))?[-#0 +]*([0-9]+|[*])?(\.([0-9]+|[*]))?' + '[hlL]?[E-GXc-giorsux%]', String.Interpol), + (r'[^\\\'"%\n]+', String), + # quotes, percents and backslashes must be parsed one at a time + (r'[\'"\\]', String), + # unhandled string formatting sign + (r'%', String) + # newlines are an error (use "nl" state) + ], + 'nl': [ + (r'\n', String) + ], + 'dqs': [ + (r'"', String, '#pop'), + # included here again for raw strings + (r'\\\\|\\"|\\\n', String.Escape), + include('strings') + ], + 'sqs': [ + (r"'", String, '#pop'), + # included here again for raw strings + (r"\\\\|\\'|\\\n", String.Escape), + include('strings') + ], + 'tdqs': [ + (r'"""', String, '#pop'), + include('strings'), + include('nl') + ], + 'tsqs': [ + (r"'''", String, '#pop'), + include('strings'), + include('nl') + ], + } diff --git a/pygments/lexers/fift.py b/pygments/lexers/fift.py new file mode 100644 index 0000000000000000000000000000000000000000..1263236c9c83b557f486b4bd4cef9ec3f894df58 --- /dev/null +++ b/pygments/lexers/fift.py @@ -0,0 +1,68 @@ +""" + pygments.lexers.fift + ~~~~~~~~~~~~~~~~~~~~ + + Lexers for fift. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +from pygments.lexer import RegexLexer, include +from pygments.token import Literal, Comment, Name, String, Number, Whitespace + +__all__ = ['FiftLexer'] + + +class FiftLexer(RegexLexer): + """ + For Fift source code. + """ + + name = 'Fift' + aliases = ['fift', 'fif'] + filenames = ['*.fif'] + url = 'https://ton-blockchain.github.io/docs/fiftbase.pdf' + version_added = '' + + tokens = { + 'root': [ + (r'\s+', Whitespace), + + include('comments'), + + (r'[\.+]?\"', String, 'string'), + + # numbers + (r'0x[0-9a-fA-F]+', Number.Hex), + (r'0b[01]+', Number.Bin), + (r'-?[0-9]+("/"-?[0-9]+)?', Number.Decimal), + + # slices + (r'b\{[01]+\}', Literal), + (r'x\{[0-9a-fA-F_]+\}', Literal), + + # byte literal + (r'B\{[0-9a-fA-F_]+\}', Literal), + + # treat anything as word + (r'\S+', Name) + ], + + 'string': [ + (r'\\.', String.Escape), + (r'\"', String, '#pop'), + (r'[^\"\r\n\\]+', String) + ], + + 'comments': [ + (r'//.*', Comment.Singleline), + (r'/\*', Comment.Multiline, 'comment'), + ], + 'comment': [ + (r'[^/*]+', Comment.Multiline), + (r'/\*', Comment.Multiline, '#push'), + (r'\*/', Comment.Multiline, '#pop'), + (r'[*/]', Comment.Multiline), + ], + } diff --git a/pygments/lexers/floscript.py b/pygments/lexers/floscript.py new file mode 100644 index 0000000000000000000000000000000000000000..9c7eaab5a6edd73ef292bc10bd08b727b60b9859 --- /dev/null +++ b/pygments/lexers/floscript.py @@ -0,0 +1,81 @@ +""" + pygments.lexers.floscript + ~~~~~~~~~~~~~~~~~~~~~~~~~ + + Lexer for FloScript + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +from pygments.lexer import RegexLexer, include, bygroups +from pygments.token import Text, Comment, Operator, Keyword, Name, String, \ + Number, Punctuation, Whitespace + +__all__ = ['FloScriptLexer'] + + +class FloScriptLexer(RegexLexer): + """ + For FloScript configuration language source code. + """ + + name = 'FloScript' + url = 'https://github.com/ioflo/ioflo' + aliases = ['floscript', 'flo'] + filenames = ['*.flo'] + version_added = '2.4' + + def innerstring_rules(ttype): + return [ + # the old style '%s' % (...) string formatting + (r'%(\(\w+\))?[-#0 +]*([0-9]+|[*])?(\.([0-9]+|[*]))?' + '[hlL]?[E-GXc-giorsux%]', String.Interpol), + # backslashes, quotes and formatting signs must be parsed one at a time + (r'[^\\\'"%\n]+', ttype), + (r'[\'"\\]', ttype), + # unhandled string formatting sign + (r'%', ttype), + # newlines are an error (use "nl" state) + ] + + tokens = { + 'root': [ + (r'\s+', Whitespace), + + (r'[]{}:(),;[]', Punctuation), + (r'(\\)(\n)', bygroups(Text, Whitespace)), + (r'\\', Text), + (r'(to|by|with|from|per|for|cum|qua|via|as|at|in|of|on|re|is|if|be|into|' + r'and|not)\b', Operator.Word), + (r'!=|==|<<|>>|[-~+/*%=<>&^|.]', Operator), + (r'(load|init|server|logger|log|loggee|first|over|under|next|done|timeout|' + r'repeat|native|benter|enter|recur|exit|precur|renter|rexit|print|put|inc|' + r'copy|set|aux|rear|raze|go|let|do|bid|ready|start|stop|run|abort|use|flo|' + r'give|take)\b', Name.Builtin), + (r'(frame|framer|house)\b', Keyword), + ('"', String, 'string'), + + include('name'), + include('numbers'), + (r'#.+$', Comment.Single), + ], + 'string': [ + ('[^"]+', String), + ('"', String, '#pop'), + ], + 'numbers': [ + (r'(\d+\.\d*|\d*\.\d+)([eE][+-]?[0-9]+)?j?', Number.Float), + (r'\d+[eE][+-]?[0-9]+j?', Number.Float), + (r'0[0-7]+j?', Number.Oct), + (r'0[bB][01]+', Number.Bin), + (r'0[xX][a-fA-F0-9]+', Number.Hex), + (r'\d+L', Number.Integer.Long), + (r'\d+j?', Number.Integer) + ], + + 'name': [ + (r'@[\w.]+', Name.Decorator), + (r'[a-zA-Z_]\w*', Name), + ], + } diff --git a/pygments/lexers/forth.py b/pygments/lexers/forth.py new file mode 100644 index 0000000000000000000000000000000000000000..51f75affb4dd886a5f9089685673520e5f192204 --- /dev/null +++ b/pygments/lexers/forth.py @@ -0,0 +1,178 @@ +""" + pygments.lexers.forth + ~~~~~~~~~~~~~~~~~~~~~ + + Lexer for the Forth language. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +import re + +from pygments.lexer import RegexLexer, bygroups +from pygments.token import Text, Comment, Keyword, Name, String, Number, \ + Whitespace + + +__all__ = ['ForthLexer'] + + +class ForthLexer(RegexLexer): + """ + Lexer for Forth files. + """ + name = 'Forth' + url = 'https://www.forth.com/forth/' + aliases = ['forth'] + filenames = ['*.frt', '*.fs'] + mimetypes = ['application/x-forth'] + version_added = '2.2' + + flags = re.IGNORECASE | re.MULTILINE + + tokens = { + 'root': [ + (r'\s+', Whitespace), + # All comment types + (r'\\.*?$', Comment.Single), + (r'\([\s].*?\)', Comment.Single), + # defining words. The next word is a new command name + (r'(:|variable|constant|value|buffer:)(\s+)', + bygroups(Keyword.Namespace, Whitespace), 'worddef'), + # strings are rather simple + (r'([.sc]")(\s+?)', bygroups(String, Whitespace), 'stringdef'), + # keywords from the various wordsets + # *** Wordset BLOCK + (r'(blk|block|buffer|evaluate|flush|load|save-buffers|update|' + # *** Wordset BLOCK-EXT + r'empty-buffers|list|refill|scr|thru|' + # *** Wordset CORE + r'\#s|\*\/mod|\+loop|\/mod|0<|0=|1\+|1-|2!|' + r'2\*|2\/|2@|2drop|2dup|2over|2swap|>body|' + r'>in|>number|>r|\?dup|abort|abort\"|abs|' + r'accept|align|aligned|allot|and|base|begin|' + r'bl|c!|c,|c@|cell\+|cells|char|char\+|' + r'chars|constant|count|cr|create|decimal|' + r'depth|do|does>|drop|dup|else|emit|environment\?|' + r'evaluate|execute|exit|fill|find|fm\/mod|' + r'here|hold|i|if|immediate|invert|j|key|' + r'leave|literal|loop|lshift|m\*|max|min|' + r'mod|move|negate|or|over|postpone|quit|' + r'r>|r@|recurse|repeat|rot|rshift|s\"|s>d|' + r'sign|sm\/rem|source|space|spaces|state|swap|' + r'then|type|u\.|u\<|um\*|um\/mod|unloop|until|' + r'variable|while|word|xor|\[char\]|\[\'\]|' + r'@|!|\#|<\#|\#>|:|;|\+|-|\*|\/|,|<|>|\|1\+|1-|\.|' + # *** Wordset CORE-EXT + r'\.r|0<>|' + r'0>|2>r|2r>|2r@|:noname|\?do|again|c\"|' + r'case|compile,|endcase|endof|erase|false|' + r'hex|marker|nip|of|pad|parse|pick|refill|' + r'restore-input|roll|save-input|source-id|to|' + r'true|tuck|u\.r|u>|unused|value|within|' + r'\[compile\]|' + # *** Wordset CORE-EXT-obsolescent + r'\#tib|convert|expect|query|span|' + r'tib|' + # *** Wordset DOUBLE + r'2constant|2literal|2variable|d\+|d-|' + r'd\.|d\.r|d0<|d0=|d2\*|d2\/|d<|d=|d>s|' + r'dabs|dmax|dmin|dnegate|m\*\/|m\+|' + # *** Wordset DOUBLE-EXT + r'2rot|du<|' + # *** Wordset EXCEPTION + r'catch|throw|' + # *** Wordset EXCEPTION-EXT + r'abort|abort\"|' + # *** Wordset FACILITY + r'at-xy|key\?|page|' + # *** Wordset FACILITY-EXT + r'ekey|ekey>char|ekey\?|emit\?|ms|time&date|' + # *** Wordset FILE + r'BIN|CLOSE-FILE|CREATE-FILE|DELETE-FILE|FILE-POSITION|' + r'FILE-SIZE|INCLUDE-FILE|INCLUDED|OPEN-FILE|R\/O|' + r'R\/W|READ-FILE|READ-LINE|REPOSITION-FILE|RESIZE-FILE|' + r'S\"|SOURCE-ID|W/O|WRITE-FILE|WRITE-LINE|' + # *** Wordset FILE-EXT + r'FILE-STATUS|FLUSH-FILE|REFILL|RENAME-FILE|' + # *** Wordset FLOAT + r'>float|d>f|' + r'f!|f\*|f\+|f-|f\/|f0<|f0=|f<|f>d|f@|' + r'falign|faligned|fconstant|fdepth|fdrop|fdup|' + r'fliteral|float\+|floats|floor|fmax|fmin|' + r'fnegate|fover|frot|fround|fswap|fvariable|' + r'represent|' + # *** Wordset FLOAT-EXT + r'df!|df@|dfalign|dfaligned|dfloat\+|' + r'dfloats|f\*\*|f\.|fabs|facos|facosh|falog|' + r'fasin|fasinh|fatan|fatan2|fatanh|fcos|fcosh|' + r'fe\.|fexp|fexpm1|fln|flnp1|flog|fs\.|fsin|' + r'fsincos|fsinh|fsqrt|ftan|ftanh|f~|precision|' + r'set-precision|sf!|sf@|sfalign|sfaligned|sfloat\+|' + r'sfloats|' + # *** Wordset LOCAL + r'\(local\)|to|' + # *** Wordset LOCAL-EXT + r'locals\||' + # *** Wordset MEMORY + r'allocate|free|resize|' + # *** Wordset SEARCH + r'definitions|find|forth-wordlist|get-current|' + r'get-order|search-wordlist|set-current|set-order|' + r'wordlist|' + # *** Wordset SEARCH-EXT + r'also|forth|only|order|previous|' + # *** Wordset STRING + r'-trailing|\/string|blank|cmove|cmove>|compare|' + r'search|sliteral|' + # *** Wordset TOOLS + r'.s|dump|see|words|' + # *** Wordset TOOLS-EXT + r';code|' + r'ahead|assembler|bye|code|cs-pick|cs-roll|' + r'editor|state|\[else\]|\[if\]|\[then\]|' + # *** Wordset TOOLS-EXT-obsolescent + r'forget|' + # Forth 2012 + r'defer|defer@|defer!|action-of|begin-structure|field:|buffer:|' + r'parse-name|buffer:|traverse-wordlist|n>r|nr>|2value|fvalue|' + r'name>interpret|name>compile|name>string|' + r'cfield:|end-structure)(?!\S)', Keyword), + + # Numbers + (r'(\$[0-9A-F]+)', Number.Hex), + (r'(\#|%|&|\-|\+)?[0-9]+', Number.Integer), + (r'(\#|%|&|\-|\+)?[0-9.]+', Keyword.Type), + # amforth specific + (r'(@i|!i|@e|!e|pause|noop|turnkey|sleep|' + r'itype|icompare|sp@|sp!|rp@|rp!|up@|up!|' + r'>a|a>|a@|a!|a@+|a@-|>b|b>|b@|b!|b@+|b@-|' + r'find-name|1ms|' + r'sp0|rp0|\(evaluate\)|int-trap|int!)(?!\S)', + Name.Constant), + # a proposal + (r'(do-recognizer|r:fail|recognizer:|get-recognizers|' + r'set-recognizers|r:float|r>comp|r>int|r>post|' + r'r:name|r:word|r:dnum|r:num|recognizer|forth-recognizer|' + r'rec:num|rec:float|rec:word)(?!\S)', Name.Decorator), + # defining words. The next word is a new command name + (r'(Evalue|Rvalue|Uvalue|Edefer|Rdefer|Udefer)(\s+)', + bygroups(Keyword.Namespace, Text), 'worddef'), + + (r'\S+', Name.Function), # Anything else is executed + + ], + 'worddef': [ + (r'\S+', Name.Class, '#pop'), + ], + 'stringdef': [ + (r'[^"]+', String, '#pop'), + ], + } + + def analyse_text(text): + """Forth uses : COMMAND ; quite a lot in a single line, so we're trying + to find that.""" + if re.search('\n:[^\n]+;\n', text): + return 0.3 diff --git a/pygments/lexers/fortran.py b/pygments/lexers/fortran.py new file mode 100644 index 0000000000000000000000000000000000000000..a6230f0249c5eddb550f92b0dcf14346d3124d95 --- /dev/null +++ b/pygments/lexers/fortran.py @@ -0,0 +1,212 @@ +""" + pygments.lexers.fortran + ~~~~~~~~~~~~~~~~~~~~~~~ + + Lexers for Fortran languages. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +import re + +from pygments.lexer import RegexLexer, bygroups, include, words, using, default +from pygments.token import Text, Comment, Operator, Keyword, Name, String, \ + Number, Punctuation, Generic + +__all__ = ['FortranLexer', 'FortranFixedLexer'] + + +class FortranLexer(RegexLexer): + """ + Lexer for FORTRAN 90 code. + """ + name = 'Fortran' + url = 'https://fortran-lang.org/' + aliases = ['fortran', 'f90'] + filenames = ['*.f03', '*.f90', '*.F03', '*.F90'] + mimetypes = ['text/x-fortran'] + version_added = '0.10' + flags = re.IGNORECASE | re.MULTILINE + + # Data Types: INTEGER, REAL, COMPLEX, LOGICAL, CHARACTER and DOUBLE PRECISION + # Operators: **, *, +, -, /, <, >, <=, >=, ==, /= + # Logical (?): NOT, AND, OR, EQV, NEQV + + # Builtins: + # http://gcc.gnu.org/onlinedocs/gcc-3.4.6/g77/Table-of-Intrinsic-Functions.html + + tokens = { + 'root': [ + (r'^#.*\n', Comment.Preproc), + (r'!.*\n', Comment), + include('strings'), + include('core'), + (r'[a-z][\w$]*', Name), + include('nums'), + (r'[\s]+', Text.Whitespace), + ], + 'core': [ + # Statements + + (r'\b(DO)(\s+)(CONCURRENT)\b', bygroups(Keyword, Text.Whitespace, Keyword)), + (r'\b(GO)(\s*)(TO)\b', bygroups(Keyword, Text.Whitespace, Keyword)), + + (words(( + 'ABSTRACT', 'ACCEPT', 'ALL', 'ALLSTOP', 'ALLOCATABLE', 'ALLOCATE', + 'ARRAY', 'ASSIGN', 'ASSOCIATE', 'ASYNCHRONOUS', 'BACKSPACE', 'BIND', + 'BLOCK', 'BLOCKDATA', 'BYTE', 'CALL', 'CASE', 'CLASS', 'CLOSE', + 'CODIMENSION', 'COMMON', 'CONTIGUOUS', 'CONTAINS', + 'CONTINUE', 'CRITICAL', 'CYCLE', 'DATA', 'DEALLOCATE', 'DECODE', + 'DEFERRED', 'DIMENSION', 'DO', 'ELEMENTAL', 'ELSE', 'ELSEIF', 'ENCODE', + 'END', 'ENDASSOCIATE', 'ENDBLOCK', 'ENDDO', 'ENDENUM', 'ENDFORALL', + 'ENDFUNCTION', 'ENDIF', 'ENDINTERFACE', 'ENDMODULE', 'ENDPROGRAM', + 'ENDSELECT', 'ENDSUBMODULE', 'ENDSUBROUTINE', 'ENDTYPE', 'ENDWHERE', + 'ENTRY', 'ENUM', 'ENUMERATOR', 'EQUIVALENCE', 'ERROR STOP', 'EXIT', + 'EXTENDS', 'EXTERNAL', 'EXTRINSIC', 'FILE', 'FINAL', 'FORALL', 'FORMAT', + 'FUNCTION', 'GENERIC', 'IF', 'IMAGES', 'IMPLICIT', + 'IMPORT', 'IMPURE', 'INCLUDE', 'INQUIRE', 'INTENT', 'INTERFACE', + 'INTRINSIC', 'IS', 'LOCK', 'MEMORY', 'MODULE', 'NAMELIST', 'NULLIFY', + 'NONE', 'NON_INTRINSIC', 'NON_OVERRIDABLE', 'NOPASS', 'ONLY', 'OPEN', + 'OPTIONAL', 'OPTIONS', 'PARAMETER', 'PASS', 'PAUSE', 'POINTER', 'PRINT', + 'PRIVATE', 'PROGRAM', 'PROCEDURE', 'PROTECTED', 'PUBLIC', 'PURE', 'READ', + 'RECURSIVE', 'RESULT', 'RETURN', 'REWIND', 'SAVE', 'SELECT', 'SEQUENCE', + 'STOP', 'SUBMODULE', 'SUBROUTINE', 'SYNC', 'SYNCALL', 'SYNCIMAGES', + 'SYNCMEMORY', 'TARGET', 'THEN', 'TYPE', 'UNLOCK', 'USE', 'VALUE', + 'VOLATILE', 'WHERE', 'WRITE', 'WHILE'), prefix=r'\b', suffix=r'\s*\b'), + Keyword), + + # Data Types + (words(( + 'CHARACTER', 'COMPLEX', 'DOUBLE PRECISION', 'DOUBLE COMPLEX', 'INTEGER', + 'LOGICAL', 'REAL', 'C_INT', 'C_SHORT', 'C_LONG', 'C_LONG_LONG', + 'C_SIGNED_CHAR', 'C_SIZE_T', 'C_INT8_T', 'C_INT16_T', 'C_INT32_T', + 'C_INT64_T', 'C_INT_LEAST8_T', 'C_INT_LEAST16_T', 'C_INT_LEAST32_T', + 'C_INT_LEAST64_T', 'C_INT_FAST8_T', 'C_INT_FAST16_T', 'C_INT_FAST32_T', + 'C_INT_FAST64_T', 'C_INTMAX_T', 'C_INTPTR_T', 'C_FLOAT', 'C_DOUBLE', + 'C_LONG_DOUBLE', 'C_FLOAT_COMPLEX', 'C_DOUBLE_COMPLEX', + 'C_LONG_DOUBLE_COMPLEX', 'C_BOOL', 'C_CHAR', 'C_PTR', 'C_FUNPTR'), + prefix=r'\b', suffix=r'\s*\b'), + Keyword.Type), + + # Operators + (r'(\*\*|\*|\+|-|\/|<|>|<=|>=|==|\/=|=)', Operator), + + (r'(::)', Keyword.Declaration), + + (r'[()\[\],:&%;.]', Punctuation), + # Intrinsics + (words(( + 'Abort', 'Abs', 'Access', 'AChar', 'ACos', 'ACosH', 'AdjustL', + 'AdjustR', 'AImag', 'AInt', 'Alarm', 'All', 'Allocated', 'ALog', + 'AMax', 'AMin', 'AMod', 'And', 'ANInt', 'Any', 'ASin', 'ASinH', + 'Associated', 'ATan', 'ATanH', 'Atomic_Define', 'Atomic_Ref', + 'BesJ', 'BesJN', 'Bessel_J0', 'Bessel_J1', 'Bessel_JN', 'Bessel_Y0', + 'Bessel_Y1', 'Bessel_YN', 'BesY', 'BesYN', 'BGE', 'BGT', 'BLE', + 'BLT', 'Bit_Size', 'BTest', 'CAbs', 'CCos', 'Ceiling', 'CExp', + 'Char', 'ChDir', 'ChMod', 'CLog', 'Cmplx', 'Command_Argument_Count', + 'Complex', 'Conjg', 'Cos', 'CosH', 'Count', 'CPU_Time', 'CShift', + 'CSin', 'CSqRt', 'CTime', 'C_Loc', 'C_Associated', + 'C_Null_Ptr', 'C_Null_Funptr', 'C_F_Pointer', 'C_F_ProcPointer', + 'C_Null_Char', 'C_Alert', 'C_Backspace', 'C_Form_Feed', 'C_FunLoc', + 'C_Sizeof', 'C_New_Line', 'C_Carriage_Return', + 'C_Horizontal_Tab', 'C_Vertical_Tab', 'DAbs', 'DACos', 'DASin', + 'DATan', 'Date_and_Time', 'DbesJ', 'DbesJN', 'DbesY', + 'DbesYN', 'Dble', 'DCos', 'DCosH', 'DDiM', 'DErF', + 'DErFC', 'DExp', 'Digits', 'DiM', 'DInt', 'DLog', 'DMax', + 'DMin', 'DMod', 'DNInt', 'Dot_Product', 'DProd', 'DSign', 'DSinH', + 'DShiftL', 'DShiftR', 'DSin', 'DSqRt', 'DTanH', 'DTan', 'DTime', + 'EOShift', 'Epsilon', 'ErF', 'ErFC', 'ErFC_Scaled', 'ETime', + 'Execute_Command_Line', 'Exit', 'Exp', 'Exponent', 'Extends_Type_Of', + 'FDate', 'FGet', 'FGetC', 'FindLoc', 'Float', 'Floor', 'Flush', + 'FNum', 'FPutC', 'FPut', 'Fraction', 'FSeek', 'FStat', 'FTell', + 'Gamma', 'GError', 'GetArg', 'Get_Command', 'Get_Command_Argument', + 'Get_Environment_Variable', 'GetCWD', 'GetEnv', 'GetGId', 'GetLog', + 'GetPId', 'GetUId', 'GMTime', 'HostNm', 'Huge', 'Hypot', 'IAbs', + 'IAChar', 'IAll', 'IAnd', 'IAny', 'IArgC', 'IBClr', 'IBits', + 'IBSet', 'IChar', 'IDate', 'IDiM', 'IDInt', 'IDNInt', 'IEOr', + 'IErrNo', 'IFix', 'Imag', 'ImagPart', 'Image_Index', 'Index', + 'Int', 'IOr', 'IParity', 'IRand', 'IsaTty', 'IShft', 'IShftC', + 'ISign', 'Iso_C_Binding', 'Is_Contiguous', 'Is_Iostat_End', + 'Is_Iostat_Eor', 'ITime', 'Kill', 'Kind', 'LBound', 'LCoBound', + 'Len', 'Len_Trim', 'LGe', 'LGt', 'Link', 'LLe', 'LLt', 'LnBlnk', + 'Loc', 'Log', 'Log_Gamma', 'Logical', 'Long', 'LShift', 'LStat', + 'LTime', 'MaskL', 'MaskR', 'MatMul', 'Max', 'MaxExponent', + 'MaxLoc', 'MaxVal', 'MClock', 'Merge', 'Merge_Bits', 'Move_Alloc', + 'Min', 'MinExponent', 'MinLoc', 'MinVal', 'Mod', 'Modulo', 'MvBits', + 'Nearest', 'New_Line', 'NInt', 'Norm2', 'Not', 'Null', 'Num_Images', + 'Or', 'Pack', 'Parity', 'PError', 'Precision', 'Present', 'Product', + 'Radix', 'Rand', 'Random_Number', 'Random_Seed', 'Range', 'Real', + 'RealPart', 'Rename', 'Repeat', 'Reshape', 'RRSpacing', 'RShift', + 'Same_Type_As', 'Scale', 'Scan', 'Second', 'Selected_Char_Kind', + 'Selected_Int_Kind', 'Selected_Real_Kind', 'Set_Exponent', 'Shape', + 'ShiftA', 'ShiftL', 'ShiftR', 'Short', 'Sign', 'Signal', 'SinH', + 'Sin', 'Sleep', 'Sngl', 'Spacing', 'Spread', 'SqRt', 'SRand', + 'Stat', 'Storage_Size', 'Sum', 'SymLnk', 'System', 'System_Clock', + 'Tan', 'TanH', 'Time', 'This_Image', 'Tiny', 'TrailZ', 'Transfer', + 'Transpose', 'Trim', 'TtyNam', 'UBound', 'UCoBound', 'UMask', + 'Unlink', 'Unpack', 'Verify', 'XOr', 'ZAbs', 'ZCos', 'ZExp', + 'ZLog', 'ZSin', 'ZSqRt'), prefix=r'\b', suffix=r'\s*\b'), + Name.Builtin), + + # Booleans + (r'\.(true|false)\.', Name.Builtin), + # Comparing Operators + (r'\.(eq|ne|lt|le|gt|ge|not|and|or|eqv|neqv)\.', Operator.Word), + ], + + 'strings': [ + (r'"(\\[0-7]+|\\[^0-7]|[^"\\])*"', String.Double), + (r"'(\\[0-7]+|\\[^0-7]|[^'\\])*'", String.Single), + ], + + 'nums': [ + (r'\d+(?![.e])(_([1-9]|[a-z]\w*))?', Number.Integer), + (r'[+-]?\d*\.\d+([ed][-+]?\d+)?(_([1-9]|[a-z]\w*))?', Number.Float), + (r'[+-]?\d+\.\d*([ed][-+]?\d+)?(_([1-9]|[a-z]\w*))?', Number.Float), + (r'[+-]?\d+(\.\d*)?[ed][-+]?\d+(_([1-9]|[a-z]\w*))?', Number.Float), + ], + } + + +class FortranFixedLexer(RegexLexer): + """ + Lexer for fixed format Fortran. + """ + name = 'FortranFixed' + aliases = ['fortranfixed'] + filenames = ['*.f', '*.F'] + url = 'https://fortran-lang.org/' + version_added = '2.1' + + flags = re.IGNORECASE + + def _lex_fortran(self, match, ctx=None): + """Lex a line just as free form fortran without line break.""" + lexer = FortranLexer() + text = match.group(0) + "\n" + for index, token, value in lexer.get_tokens_unprocessed(text): + value = value.replace('\n', '') + if value != '': + yield index, token, value + + tokens = { + 'root': [ + (r'[C*].*\n', Comment), + (r'#.*\n', Comment.Preproc), + (r' {0,4}!.*\n', Comment), + (r'(.{5})', Name.Label, 'cont-char'), + (r'.*\n', using(FortranLexer)), + ], + 'cont-char': [ + (' ', Text, 'code'), + ('0', Comment, 'code'), + ('.', Generic.Strong, 'code'), + ], + 'code': [ + (r'(.{66})(.*)(\n)', + bygroups(_lex_fortran, Comment, Text.Whitespace), 'root'), + (r'(.*)(\n)', bygroups(_lex_fortran, Text.Whitespace), 'root'), + default('root'), + ] + } diff --git a/pygments/lexers/foxpro.py b/pygments/lexers/foxpro.py new file mode 100644 index 0000000000000000000000000000000000000000..3e7c056975e03357e067fd2a3810ac57670d8dbc --- /dev/null +++ b/pygments/lexers/foxpro.py @@ -0,0 +1,427 @@ +""" + pygments.lexers.foxpro + ~~~~~~~~~~~~~~~~~~~~~~ + + Simple lexer for Microsoft Visual FoxPro source code. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +import re + +from pygments.lexer import RegexLexer +from pygments.token import Punctuation, Text, Comment, Operator, Keyword, \ + Name, String + +__all__ = ['FoxProLexer'] + + +class FoxProLexer(RegexLexer): + """Lexer for Microsoft Visual FoxPro language. + + FoxPro syntax allows to shorten all keywords and function names + to 4 characters. Shortened forms are not recognized by this lexer. + """ + + name = 'FoxPro' + aliases = ['foxpro', 'vfp', 'clipper', 'xbase'] + filenames = ['*.PRG', '*.prg'] + version_added = '1.6' + mimetype = [] + url = 'https://learn.microsoft.com/en-us/previous-versions/visualstudio/foxpro' + + flags = re.IGNORECASE | re.MULTILINE + + tokens = { + 'root': [ + (r';\s*\n', Punctuation), # consume newline + (r'(^|\n)\s*', Text, 'newline'), + + # Square brackets may be used for array indices + # and for string literal. Look for arrays + # before matching string literals. + (r'(?<=\w)\[[0-9, ]+\]', Text), + (r'\'[^\'\n]*\'|"[^"\n]*"|\[[^]*]\]', String), + (r'(^\s*\*|&&|&&).*?\n', Comment.Single), + + (r'(ABS|ACLASS|ACOPY|ACOS|ADATABASES|ADBOBJECTS|ADDBS|' + r'ADDPROPERTY|ADEL|ADIR|ADLLS|ADOCKSTATE|AELEMENT|AERROR|' + r'AEVENTS|AFIELDS|AFONT|AGETCLASS|AGETFILEVERSION|AINS|' + r'AINSTANCE|ALANGUAGE|ALEN|ALIAS|ALINES|ALLTRIM|' + r'AMEMBERS|AMOUSEOBJ|ANETRESOURCES|APRINTERS|APROCINFO|' + r'ASC|ASCAN|ASELOBJ|ASESSIONS|ASIN|ASORT|ASQLHANDLES|' + r'ASTACKINFO|ASUBSCRIPT|AT|AT_C|ATAGINFO|ATAN|ATC|ATCC|' + r'ATCLINE|ATLINE|ATN2|AUSED|AVCXCLASSES|BAR|BARCOUNT|' + r'BARPROMPT|BETWEEN|BINDEVENT|BINTOC|BITAND|BITCLEAR|' + r'BITLSHIFT|BITNOT|BITOR|BITRSHIFT|BITSET|BITTEST|BITXOR|' + r'BOF|CANDIDATE|CAPSLOCK|CAST|CDOW|CDX|CEILING|CHR|CHRSAW|' + r'CHRTRAN|CHRTRANC|CLEARRESULTSET|CMONTH|CNTBAR|CNTPAD|COL|' + r'COM|Functions|COMARRAY|COMCLASSINFO|COMPOBJ|COMPROP|' + r'COMRETURNERROR|COS|CPCONVERT|CPCURRENT|CPDBF|CREATEBINARY|' + r'CREATEOBJECT|CREATEOBJECTEX|CREATEOFFLINE|CTOBIN|CTOD|' + r'CTOT|CURDIR|CURSORGETPROP|CURSORSETPROP|CURSORTOXML|' + r'CURVAL|DATE|DATETIME|DAY|DBC|DBF|DBGETPROP|DBSETPROP|' + r'DBUSED|DDEAbortTrans|DDEAdvise|DDEEnabled|DDEExecute|' + r'DDEInitiate|DDELastError|DDEPoke|DDERequest|DDESetOption|' + r'DDESetService|DDESetTopic|DDETerminate|DEFAULTEXT|' + r'DELETED|DESCENDING|DIFFERENCE|DIRECTORY|DISKSPACE|' + r'DisplayPath|DMY|DODEFAULT|DOW|DRIVETYPE|DROPOFFLINE|' + r'DTOC|DTOR|DTOS|DTOT|EDITSOURCE|EMPTY|EOF|ERROR|EVAL(UATE)?|' + r'EVENTHANDLER|EVL|EXECSCRIPT|EXP|FCHSIZE|FCLOSE|FCOUNT|' + r'FCREATE|FDATE|FEOF|FERROR|FFLUSH|FGETS|FIELD|FILE|' + r'FILETOSTR|FILTER|FKLABEL|FKMAX|FLDLIST|FLOCK|FLOOR|' + r'FONTMETRIC|FOPEN|FOR|FORCEEXT|FORCEPATH|FOUND|FPUTS|' + r'FREAD|FSEEK|FSIZE|FTIME|FULLPATH|FV|FWRITE|' + r'GETAUTOINCVALUE|GETBAR|GETCOLOR|GETCP|GETDIR|GETENV|' + r'GETFILE|GETFLDSTATE|GETFONT|GETINTERFACE|' + r'GETNEXTMODIFIED|GETOBJECT|GETPAD|GETPEM|GETPICT|' + r'GETPRINTER|GETRESULTSET|GETWORDCOUNT|GETWORDNUM|' + r'GETCURSORADAPTER|GOMONTH|HEADER|HOME|HOUR|ICASE|' + r'IDXCOLLATE|IIF|IMESTATUS|INDBC|INDEXSEEK|INKEY|INLIST|' + r'INPUTBOX|INSMODE|INT|ISALPHA|ISBLANK|ISCOLOR|ISDIGIT|' + r'ISEXCLUSIVE|ISFLOCKED|ISLEADBYTE|ISLOWER|ISMEMOFETCHED|' + r'ISMOUSE|ISNULL|ISPEN|ISREADONLY|ISRLOCKED|' + r'ISTRANSACTABLE|ISUPPER|JUSTDRIVE|JUSTEXT|JUSTFNAME|' + r'JUSTPATH|JUSTSTEM|KEY|KEYMATCH|LASTKEY|LEFT|LEFTC|LEN|' + r'LENC|LIKE|LIKEC|LINENO|LOADPICTURE|LOCFILE|LOCK|LOG|' + r'LOG10|LOOKUP|LOWER|LTRIM|LUPDATE|MAKETRANSACTABLE|MAX|' + r'MCOL|MDOWN|MDX|MDY|MEMLINES|MEMORY|MENU|MESSAGE|' + r'MESSAGEBOX|MIN|MINUTE|MLINE|MOD|MONTH|MRKBAR|MRKPAD|' + r'MROW|MTON|MWINDOW|NDX|NEWOBJECT|NORMALIZE|NTOM|NUMLOCK|' + r'NVL|OBJNUM|OBJTOCLIENT|OBJVAR|OCCURS|OEMTOANSI|OLDVAL|' + r'ON|ORDER|OS|PAD|PADL|PARAMETERS|PAYMENT|PCOL|PCOUNT|' + r'PEMSTATUS|PI|POPUP|PRIMARY|PRINTSTATUS|PRMBAR|PRMPAD|' + r'PROGRAM|PROMPT|PROPER|PROW|PRTINFO|PUTFILE|PV|QUARTER|' + r'RAISEEVENT|RAND|RAT|RATC|RATLINE|RDLEVEL|READKEY|RECCOUNT|' + r'RECNO|RECSIZE|REFRESH|RELATION|REPLICATE|REQUERY|RGB|' + r'RGBSCHEME|RIGHT|RIGHTC|RLOCK|ROUND|ROW|RTOD|RTRIM|' + r'SAVEPICTURE|SCHEME|SCOLS|SEC|SECONDS|SEEK|SELECT|SET|' + r'SETFLDSTATE|SETRESULTSET|SIGN|SIN|SKPBAR|SKPPAD|SOUNDEX|' + r'SPACE|SQLCANCEL|SQLCOLUMNS|SQLCOMMIT|SQLCONNECT|' + r'SQLDISCONNECT|SQLEXEC|SQLGETPROP|SQLIDLEDISCONNECT|' + r'SQLMORERESULTS|SQLPREPARE|SQLROLLBACK|SQLSETPROP|' + r'SQLSTRINGCONNECT|SQLTABLES|SQRT|SROWS|STR|STRCONV|' + r'STREXTRACT|STRTOFILE|STRTRAN|STUFF|STUFFC|SUBSTR|' + r'SUBSTRC|SYS|SYSMETRIC|TABLEREVERT|TABLEUPDATE|TAG|' + r'TAGCOUNT|TAGNO|TAN|TARGET|TEXTMERGE|TIME|TRANSFORM|' + r'TRIM|TTOC|TTOD|TXNLEVEL|TXTWIDTH|TYPE|UNBINDEVENTS|' + r'UNIQUE|UPDATED|UPPER|USED|VAL|VARREAD|VARTYPE|VERSION|' + r'WBORDER|WCHILD|WCOLS|WDOCKABLE|WEEK|WEXIST|WFONT|WLAST|' + r'WLCOL|WLROW|WMAXIMUM|WMINIMUM|WONTOP|WOUTPUT|WPARENT|' + r'WREAD|WROWS|WTITLE|WVISIBLE|XMLTOCURSOR|XMLUPDATEGRAM|' + r'YEAR)(?=\s*\()', Name.Function), + + (r'_ALIGNMENT|_ASCIICOLS|_ASCIIROWS|_ASSIST|_BEAUTIFY|_BOX|' + r'_BROWSER|_BUILDER|_CALCMEM|_CALCVALUE|_CLIPTEXT|_CONVERTER|' + r'_COVERAGE|_CUROBJ|_DBLCLICK|_DIARYDATE|_DOS|_FOXDOC|_FOXREF|' + r'_GALLERY|_GENGRAPH|_GENHTML|_GENMENU|_GENPD|_GENSCRN|' + r'_GENXTAB|_GETEXPR|_INCLUDE|_INCSEEK|_INDENT|_LMARGIN|_MAC|' + r'_MENUDESIGNER|_MLINE|_PADVANCE|_PAGENO|_PAGETOTAL|_PBPAGE|' + r'_PCOLNO|_PCOPIES|_PDRIVER|_PDSETUP|_PECODE|_PEJECT|_PEPAGE|' + r'_PLENGTH|_PLINENO|_PLOFFSET|_PPITCH|_PQUALITY|_PRETEXT|' + r'_PSCODE|_PSPACING|_PWAIT|_RMARGIN|_REPORTBUILDER|' + r'_REPORTOUTPUT|_REPORTPREVIEW|_SAMPLES|_SCCTEXT|_SCREEN|' + r'_SHELL|_SPELLCHK|_STARTUP|_TABS|_TALLY|_TASKPANE|_TEXT|' + r'_THROTTLE|_TOOLBOX|_TOOLTIPTIMEOUT|_TRANSPORT|_TRIGGERLEVEL|' + r'_UNIX|_VFP|_WINDOWS|_WIZARD|_WRAP', Keyword.Pseudo), + + (r'THISFORMSET|THISFORM|THIS', Name.Builtin), + + (r'Application|CheckBox|Collection|Column|ComboBox|' + r'CommandButton|CommandGroup|Container|Control|CursorAdapter|' + r'Cursor|Custom|DataEnvironment|DataObject|EditBox|' + r'Empty|Exception|Fields|Files|File|FormSet|Form|FoxCode|' + r'Grid|Header|Hyperlink|Image|Label|Line|ListBox|Objects|' + r'OptionButton|OptionGroup|PageFrame|Page|ProjectHook|Projects|' + r'Project|Relation|ReportListener|Separator|Servers|Server|' + r'Session|Shape|Spinner|Tables|TextBox|Timer|ToolBar|' + r'XMLAdapter|XMLField|XMLTable', Name.Class), + + (r'm\.[a-z_]\w*', Name.Variable), + (r'\.(F|T|AND|OR|NOT|NULL)\.|\b(AND|OR|NOT|NULL)\b', Operator.Word), + + (r'\.(ActiveColumn|ActiveControl|ActiveForm|ActivePage|' + r'ActiveProject|ActiveRow|AddLineFeeds|ADOCodePage|Alias|' + r'Alignment|Align|AllowAddNew|AllowAutoColumnFit|' + r'AllowCellSelection|AllowDelete|AllowHeaderSizing|' + r'AllowInsert|AllowModalMessages|AllowOutput|AllowRowSizing|' + r'AllowSimultaneousFetch|AllowTabs|AllowUpdate|' + r'AlwaysOnBottom|AlwaysOnTop|Anchor|Application|' + r'AutoActivate|AutoCenter|AutoCloseTables|AutoComplete|' + r'AutoCompSource|AutoCompTable|AutoHideScrollBar|' + r'AutoIncrement|AutoOpenTables|AutoRelease|AutoSize|' + r'AutoVerbMenu|AutoYield|BackColor|ForeColor|BackStyle|' + r'BaseClass|BatchUpdateCount|BindControls|BorderColor|' + r'BorderStyle|BorderWidth|BoundColumn|BoundTo|Bound|' + r'BreakOnError|BufferModeOverride|BufferMode|' + r'BuildDateTime|ButtonCount|Buttons|Cancel|Caption|' + r'Centered|Century|ChildAlias|ChildOrder|ChildTable|' + r'ClassLibrary|Class|ClipControls|Closable|CLSID|CodePage|' + r'ColorScheme|ColorSource|ColumnCount|ColumnLines|' + r'ColumnOrder|Columns|ColumnWidths|CommandClauses|' + r'Comment|CompareMemo|ConflictCheckCmd|ConflictCheckType|' + r'ContinuousScroll|ControlBox|ControlCount|Controls|' + r'ControlSource|ConversionFunc|Count|CurrentControl|' + r'CurrentDataSession|CurrentPass|CurrentX|CurrentY|' + r'CursorSchema|CursorSource|CursorStatus|Curvature|' + r'Database|DataSessionID|DataSession|DataSourceType|' + r'DataSource|DataType|DateFormat|DateMark|Debug|' + r'DeclareXMLPrefix|DEClassLibrary|DEClass|DefaultFilePath|' + r'Default|DefOLELCID|DeleteCmdDataSourceType|DeleteCmdDataSource|' + r'DeleteCmd|DeleteMark|Description|Desktop|' + r'Details|DisabledBackColor|DisabledForeColor|' + r'DisabledItemBackColor|DisabledItemForeColor|' + r'DisabledPicture|DisableEncode|DisplayCount|' + r'DisplayValue|Dockable|Docked|DockPosition|' + r'DocumentFile|DownPicture|DragIcon|DragMode|DrawMode|' + r'DrawStyle|DrawWidth|DynamicAlignment|DynamicBackColor|' + r'DynamicForeColor|DynamicCurrentControl|DynamicFontBold|' + r'DynamicFontItalic|DynamicFontStrikethru|' + r'DynamicFontUnderline|DynamicFontName|DynamicFontOutline|' + r'DynamicFontShadow|DynamicFontSize|DynamicInputMask|' + r'DynamicLineHeight|EditorOptions|Enabled|' + r'EnableHyperlinks|Encrypted|ErrorNo|Exclude|Exclusive|' + r'FetchAsNeeded|FetchMemoCmdList|FetchMemoDataSourceType|' + r'FetchMemoDataSource|FetchMemo|FetchSize|' + r'FileClassLibrary|FileClass|FillColor|FillStyle|Filter|' + r'FirstElement|FirstNestedTable|Flags|FontBold|FontItalic|' + r'FontStrikethru|FontUnderline|FontCharSet|FontCondense|' + r'FontExtend|FontName|FontOutline|FontShadow|FontSize|' + r'ForceCloseTag|Format|FormCount|FormattedOutput|Forms|' + r'FractionDigits|FRXDataSession|FullName|GDIPlusGraphics|' + r'GridLineColor|GridLines|GridLineWidth|HalfHeightCaption|' + r'HeaderClassLibrary|HeaderClass|HeaderHeight|Height|' + r'HelpContextID|HideSelection|HighlightBackColor|' + r'HighlightForeColor|HighlightStyle|HighlightRowLineWidth|' + r'HighlightRow|Highlight|HomeDir|Hours|HostName|' + r'HScrollSmallChange|hWnd|Icon|IncrementalSearch|Increment|' + r'InitialSelectedAlias|InputMask|InsertCmdDataSourceType|' + r'InsertCmdDataSource|InsertCmdRefreshCmd|' + r'InsertCmdRefreshFieldList|InsertCmdRefreshKeyFieldList|' + r'InsertCmd|Instancing|IntegralHeight|' + r'Interval|IMEMode|IsAttribute|IsBase64|IsBinary|IsNull|' + r'IsDiffGram|IsLoaded|ItemBackColor,|ItemData|ItemIDData|' + r'ItemTips|IXMLDOMElement|KeyboardHighValue|KeyboardLowValue|' + r'Keyfield|KeyFieldList|KeyPreview|KeySort|LanguageOptions|' + r'LeftColumn|Left|LineContents|LineNo|LineSlant|LinkMaster|' + r'ListCount|ListenerType|ListIndex|ListItemID|ListItem|' + r'List|LockColumnsLeft|LockColumns|LockScreen|MacDesktop|' + r'MainFile|MapN19_4ToCurrency|MapBinary|MapVarchar|Margin|' + r'MaxButton|MaxHeight|MaxLeft|MaxLength|MaxRecords|MaxTop|' + r'MaxWidth|MDIForm|MemberClassLibrary|MemberClass|' + r'MemoWindow|Message|MinButton|MinHeight|MinWidth|' + r'MouseIcon|MousePointer|Movable|MoverBars|MultiSelect|' + r'Name|NestedInto|NewIndex|NewItemID|NextSiblingTable|' + r'NoCpTrans|NoDataOnLoad|NoData|NullDisplay|' + r'NumberOfElements|Object|OLEClass|OLEDragMode|' + r'OLEDragPicture|OLEDropEffects|OLEDropHasData|' + r'OLEDropMode|OLEDropTextInsertion|OLELCID|' + r'OLERequestPendingTimeout|OLEServerBusyRaiseError|' + r'OLEServerBusyTimeout|OLETypeAllowed|OneToMany|' + r'OpenViews|OpenWindow|Optimize|OrderDirection|Order|' + r'OutputPageCount|OutputType|PageCount|PageHeight|' + r'PageNo|PageOrder|Pages|PageTotal|PageWidth|' + r'PanelLink|Panel|ParentAlias|ParentClass|ParentTable|' + r'Parent|Partition|PasswordChar|PictureMargin|' + r'PicturePosition|PictureSpacing|PictureSelectionDisplay|' + r'PictureVal|Picture|Prepared|' + r'PolyPoints|PreserveWhiteSpace|PreviewContainer|' + r'PrintJobName|Procedure|PROCESSID|ProgID|ProjectHookClass|' + r'ProjectHookLibrary|ProjectHook|QuietMode|' + r'ReadCycle|ReadLock|ReadMouse|ReadObject|ReadOnly|' + r'ReadSave|ReadTimeout|RecordMark|RecordSourceType|' + r'RecordSource|RefreshAlias|' + r'RefreshCmdDataSourceType|RefreshCmdDataSource|RefreshCmd|' + r'RefreshIgnoreFieldList|RefreshTimeStamp|RelationalExpr|' + r'RelativeColumn|RelativeRow|ReleaseType|Resizable|' + r'RespectCursorCP|RespectNesting|RightToLeft|RotateFlip|' + r'Rotation|RowColChange|RowHeight|RowSourceType|' + r'RowSource|ScaleMode|SCCProvider|SCCStatus|ScrollBars|' + r'Seconds|SelectCmd|SelectedID|' + r'SelectedItemBackColor|SelectedItemForeColor|Selected|' + r'SelectionNamespaces|SelectOnEntry|SelLength|SelStart|' + r'SelText|SendGDIPlusImage|SendUpdates|ServerClassLibrary|' + r'ServerClass|ServerHelpFile|ServerName|' + r'ServerProject|ShowTips|ShowInTaskbar|ShowWindow|' + r'Sizable|SizeBox|SOM|Sorted|Sparse|SpecialEffect|' + r'SpinnerHighValue|SpinnerLowValue|SplitBar|StackLevel|' + r'StartMode|StatusBarText|StatusBar|Stretch|StrictDateEntry|' + r'Style|TabIndex|Tables|TabOrientation|Tabs|TabStop|' + r'TabStretch|TabStyle|Tag|TerminateRead|Text|Themes|' + r'ThreadID|TimestampFieldList|TitleBar|ToolTipText|' + r'TopIndex|TopItemID|Top|TwoPassProcess|TypeLibCLSID|' + r'TypeLibDesc|TypeLibName|Type|Unicode|UpdatableFieldList|' + r'UpdateCmdDataSourceType|UpdateCmdDataSource|' + r'UpdateCmdRefreshCmd|UpdateCmdRefreshFieldList|' + r'UpdateCmdRefreshKeyFieldList|UpdateCmd|' + r'UpdateGramSchemaLocation|UpdateGram|UpdateNameList|UpdateType|' + r'UseCodePage|UseCursorSchema|UseDeDataSource|UseMemoSize|' + r'UserValue|UseTransactions|UTF8Encoded|Value|VersionComments|' + r'VersionCompany|VersionCopyright|VersionDescription|' + r'VersionNumber|VersionProduct|VersionTrademarks|Version|' + r'VFPXMLProgID|ViewPortHeight|ViewPortLeft|' + r'ViewPortTop|ViewPortWidth|VScrollSmallChange|View|Visible|' + r'VisualEffect|WhatsThisButton|WhatsThisHelpID|WhatsThisHelp|' + r'WhereType|Width|WindowList|WindowState|WindowType|WordWrap|' + r'WrapCharInCDATA|WrapInCDATA|WrapMemoInCDATA|XMLAdapter|' + r'XMLConstraints|XMLNameIsXPath|XMLNamespace|XMLName|' + r'XMLPrefix|XMLSchemaLocation|XMLTable|XMLType|' + r'XSDfractionDigits|XSDmaxLength|XSDtotalDigits|' + r'XSDtype|ZoomBox)', Name.Attribute), + + (r'\.(ActivateCell|AddColumn|AddItem|AddListItem|AddObject|' + r'AddProperty|AddTableSchema|AddToSCC|Add|' + r'ApplyDiffgram|Attach|AutoFit|AutoOpen|Box|Build|' + r'CancelReport|ChangesToCursor|CheckIn|CheckOut|Circle|' + r'CleanUp|ClearData|ClearStatus|Clear|CloneObject|CloseTables|' + r'Close|Cls|CursorAttach|CursorDetach|CursorFill|' + r'CursorRefresh|DataToClip|DelayedMemoFetch|DeleteColumn|' + r'Dock|DoMessage|DoScroll|DoStatus|DoVerb|Drag|Draw|Eval|' + r'GetData|GetDockState|GetFormat|GetKey|GetLatestVersion|' + r'GetPageHeight|GetPageWidth|Help|Hide|IncludePageInOutput|' + r'IndexToItemID|ItemIDToIndex|Item|LoadXML|Line|Modify|' + r'MoveItem|Move|Nest|OLEDrag|OnPreviewClose|OutputPage|' + r'Point|Print|PSet|Quit|ReadExpression|ReadMethod|' + r'RecordRefresh|Refresh|ReleaseXML|Release|RemoveFromSCC|' + r'RemoveItem|RemoveListItem|RemoveObject|Remove|' + r'Render|Requery|RequestData|ResetToDefault|Reset|Run|' + r'SaveAsClass|SaveAs|SetAll|SetData|SetFocus|SetFormat|' + r'SetMain|SetVar|SetViewPort|ShowWhatsThis|Show|' + r'SupportsListenerType|TextHeight|TextWidth|ToCursor|' + r'ToXML|UndoCheckOut|Unnest|UpdateStatus|WhatsThisMode|' + r'WriteExpression|WriteMethod|ZOrder)', Name.Function), + + (r'\.(Activate|AdjustObjectSize|AfterBand|AfterBuild|' + r'AfterCloseTables|AfterCursorAttach|AfterCursorClose|' + r'AfterCursorDetach|AfterCursorFill|AfterCursorRefresh|' + r'AfterCursorUpdate|AfterDelete|AfterInsert|' + r'AfterRecordRefresh|AfterUpdate|AfterDock|AfterReport|' + r'AfterRowColChange|BeforeBand|BeforeCursorAttach|' + r'BeforeCursorClose|BeforeCursorDetach|BeforeCursorFill|' + r'BeforeCursorRefresh|BeforeCursorUpdate|BeforeDelete|' + r'BeforeInsert|BeforeDock|BeforeOpenTables|' + r'BeforeRecordRefresh|BeforeReport|BeforeRowColChange|' + r'BeforeUpdate|Click|dbc_Activate|dbc_AfterAddTable|' + r'dbc_AfterAppendProc|dbc_AfterCloseTable|dbc_AfterCopyProc|' + r'dbc_AfterCreateConnection|dbc_AfterCreateOffline|' + r'dbc_AfterCreateTable|dbc_AfterCreateView|dbc_AfterDBGetProp|' + r'dbc_AfterDBSetProp|dbc_AfterDeleteConnection|' + r'dbc_AfterDropOffline|dbc_AfterDropTable|' + r'dbc_AfterModifyConnection|dbc_AfterModifyProc|' + r'dbc_AfterModifyTable|dbc_AfterModifyView|dbc_AfterOpenTable|' + r'dbc_AfterRemoveTable|dbc_AfterRenameConnection|' + r'dbc_AfterRenameTable|dbc_AfterRenameView|' + r'dbc_AfterValidateData|dbc_BeforeAddTable|' + r'dbc_BeforeAppendProc|dbc_BeforeCloseTable|' + r'dbc_BeforeCopyProc|dbc_BeforeCreateConnection|' + r'dbc_BeforeCreateOffline|dbc_BeforeCreateTable|' + r'dbc_BeforeCreateView|dbc_BeforeDBGetProp|' + r'dbc_BeforeDBSetProp|dbc_BeforeDeleteConnection|' + r'dbc_BeforeDropOffline|dbc_BeforeDropTable|' + r'dbc_BeforeModifyConnection|dbc_BeforeModifyProc|' + r'dbc_BeforeModifyTable|dbc_BeforeModifyView|' + r'dbc_BeforeOpenTable|dbc_BeforeRemoveTable|' + r'dbc_BeforeRenameConnection|dbc_BeforeRenameTable|' + r'dbc_BeforeRenameView|dbc_BeforeValidateData|' + r'dbc_CloseData|dbc_Deactivate|dbc_ModifyData|dbc_OpenData|' + r'dbc_PackData|DblClick|Deactivate|Deleted|Destroy|DoCmd|' + r'DownClick|DragDrop|DragOver|DropDown|ErrorMessage|Error|' + r'EvaluateContents|GotFocus|Init|InteractiveChange|KeyPress|' + r'LoadReport|Load|LostFocus|Message|MiddleClick|MouseDown|' + r'MouseEnter|MouseLeave|MouseMove|MouseUp|MouseWheel|Moved|' + r'OLECompleteDrag|OLEDragOver|OLEGiveFeedback|OLESetData|' + r'OLEStartDrag|OnMoveItem|Paint|ProgrammaticChange|' + r'QueryAddFile|QueryModifyFile|QueryNewFile|QueryRemoveFile|' + r'QueryRunFile|QueryUnload|RangeHigh|RangeLow|ReadActivate|' + r'ReadDeactivate|ReadShow|ReadValid|ReadWhen|Resize|' + r'RightClick|SCCInit|SCCDestroy|Scrolled|Timer|UIEnable|' + r'UnDock|UnloadReport|Unload|UpClick|Valid|When)', Name.Function), + + (r'\s+', Text), + # everything else is not colored + (r'.', Text), + ], + 'newline': [ + (r'\*.*?$', Comment.Single, '#pop'), + (r'(ACCEPT|ACTIVATE\s*MENU|ACTIVATE\s*POPUP|ACTIVATE\s*SCREEN|' + r'ACTIVATE\s*WINDOW|APPEND|APPEND\s*FROM|APPEND\s*FROM\s*ARRAY|' + r'APPEND\s*GENERAL|APPEND\s*MEMO|ASSIST|AVERAGE|BLANK|BROWSE|' + r'BUILD\s*APP|BUILD\s*EXE|BUILD\s*PROJECT|CALCULATE|CALL|' + r'CANCEL|CHANGE|CLEAR|CLOSE|CLOSE\s*MEMO|COMPILE|CONTINUE|' + r'COPY\s*FILE|COPY\s*INDEXES|COPY\s*MEMO|COPY\s*STRUCTURE|' + r'COPY\s*STRUCTURE\s*EXTENDED|COPY\s*TAG|COPY\s*TO|' + r'COPY\s*TO\s*ARRAY|COUNT|CREATE|CREATE\s*COLOR\s*SET|' + r'CREATE\s*CURSOR|CREATE\s*FROM|CREATE\s*LABEL|CREATE\s*MENU|' + r'CREATE\s*PROJECT|CREATE\s*QUERY|CREATE\s*REPORT|' + r'CREATE\s*SCREEN|CREATE\s*TABLE|CREATE\s*VIEW|DDE|' + r'DEACTIVATE\s*MENU|DEACTIVATE\s*POPUP|DEACTIVATE\s*WINDOW|' + r'DECLARE|DEFINE\s*BAR|DEFINE\s*BOX|DEFINE\s*MENU|' + r'DEFINE\s*PAD|DEFINE\s*POPUP|DEFINE\s*WINDOW|DELETE|' + r'DELETE\s*FILE|DELETE\s*TAG|DIMENSION|DIRECTORY|DISPLAY|' + r'DISPLAY\s*FILES|DISPLAY\s*MEMORY|DISPLAY\s*STATUS|' + r'DISPLAY\s*STRUCTURE|DO|EDIT|EJECT|EJECT\s*PAGE|ERASE|' + r'EXIT|EXPORT|EXTERNAL|FILER|FIND|FLUSH|FUNCTION|GATHER|' + r'GETEXPR|GO|GOTO|HELP|HIDE\s*MENU|HIDE\s*POPUP|' + r'HIDE\s*WINDOW|IMPORT|INDEX|INPUT|INSERT|JOIN|KEYBOARD|' + r'LABEL|LIST|LOAD|LOCATE|LOOP|MENU|MENU\s*TO|MODIFY\s*COMMAND|' + r'MODIFY\s*FILE|MODIFY\s*GENERAL|MODIFY\s*LABEL|MODIFY\s*MEMO|' + r'MODIFY\s*MENU|MODIFY\s*PROJECT|MODIFY\s*QUERY|' + r'MODIFY\s*REPORT|MODIFY\s*SCREEN|MODIFY\s*STRUCTURE|' + r'MODIFY\s*WINDOW|MOVE\s*POPUP|MOVE\s*WINDOW|NOTE|' + r'ON\s*APLABOUT|ON\s*BAR|ON\s*ERROR|ON\s*ESCAPE|' + r'ON\s*EXIT\s*BAR|ON\s*EXIT\s*MENU|ON\s*EXIT\s*PAD|' + r'ON\s*EXIT\s*POPUP|ON\s*KEY|ON\s*KEY\s*=|ON\s*KEY\s*LABEL|' + r'ON\s*MACHELP|ON\s*PAD|ON\s*PAGE|ON\s*READERROR|' + r'ON\s*SELECTION\s*BAR|ON\s*SELECTION\s*MENU|' + r'ON\s*SELECTION\s*PAD|ON\s*SELECTION\s*POPUP|ON\s*SHUTDOWN|' + r'PACK|PARAMETERS|PLAY\s*MACRO|POP\s*KEY|POP\s*MENU|' + r'POP\s*POPUP|PRIVATE|PROCEDURE|PUBLIC|PUSH\s*KEY|' + r'PUSH\s*MENU|PUSH\s*POPUP|QUIT|READ|READ\s*MENU|RECALL|' + r'REINDEX|RELEASE|RELEASE\s*MODULE|RENAME|REPLACE|' + r'REPLACE\s*FROM\s*ARRAY|REPORT|RESTORE\s*FROM|' + r'RESTORE\s*MACROS|RESTORE\s*SCREEN|RESTORE\s*WINDOW|' + r'RESUME|RETRY|RETURN|RUN|RUN\s*\/N"|RUNSCRIPT|' + r'SAVE\s*MACROS|SAVE\s*SCREEN|SAVE\s*TO|SAVE\s*WINDOWS|' + r'SCATTER|SCROLL|SEEK|SELECT|SET|SET\s*ALTERNATE|' + r'SET\s*ANSI|SET\s*APLABOUT|SET\s*AUTOSAVE|SET\s*BELL|' + r'SET\s*BLINK|SET\s*BLOCKSIZE|SET\s*BORDER|SET\s*BRSTATUS|' + r'SET\s*CARRY|SET\s*CENTURY|SET\s*CLEAR|SET\s*CLOCK|' + r'SET\s*COLLATE|SET\s*COLOR\s*OF|SET\s*COLOR\s*OF\s*SCHEME|' + r'SET\s*COLOR\s*SET|SET\s*COLOR\s*TO|SET\s*COMPATIBLE|' + r'SET\s*CONFIRM|SET\s*CONSOLE|SET\s*CURRENCY|SET\s*CURSOR|' + r'SET\s*DATE|SET\s*DEBUG|SET\s*DECIMALS|SET\s*DEFAULT|' + r'SET\s*DELETED|SET\s*DELIMITERS|SET\s*DEVELOPMENT|' + r'SET\s*DEVICE|SET\s*DISPLAY|SET\s*DOHISTORY|SET\s*ECHO|' + r'SET\s*ESCAPE|SET\s*EXACT|SET\s*EXCLUSIVE|SET\s*FIELDS|' + r'SET\s*FILTER|SET\s*FIXED|SET\s*FORMAT|SET\s*FULLPATH|' + r'SET\s*FUNCTION|SET\s*HEADINGS|SET\s*HELP|SET\s*HELPFILTER|' + r'SET\s*HOURS|SET\s*INDEX|SET\s*INTENSITY|SET\s*KEY|' + r'SET\s*KEYCOMP|SET\s*LIBRARY|SET\s*LOCK|SET\s*LOGERRORS|' + r'SET\s*MACDESKTOP|SET\s*MACHELP|SET\s*MACKEY|SET\s*MARGIN|' + r'SET\s*MARK\s*OF|SET\s*MARK\s*TO|SET\s*MEMOWIDTH|' + r'SET\s*MESSAGE|SET\s*MOUSE|SET\s*MULTILOCKS|SET\s*NEAR|' + r'SET\s*NOCPTRANS|SET\s*NOTIFY|SET\s*ODOMETER|SET\s*OPTIMIZE|' + r'SET\s*ORDER|SET\s*PALETTE|SET\s*PATH|SET\s*PDSETUP|' + r'SET\s*POINT|SET\s*PRINTER|SET\s*PROCEDURE|SET\s*READBORDER|' + r'SET\s*REFRESH|SET\s*RELATION|SET\s*RELATION\s*OFF|' + r'SET\s*REPROCESS|SET\s*RESOURCE|SET\s*SAFETY|SET\s*SCOREBOARD|' + r'SET\s*SEPARATOR|SET\s*SHADOWS|SET\s*SKIP|SET\s*SKIP\s*OF|' + r'SET\s*SPACE|SET\s*STATUS|SET\s*STATUS\s*BAR|SET\s*STEP|' + r'SET\s*STICKY|SET\s*SYSMENU|SET\s*TALK|SET\s*TEXTMERGE|' + r'SET\s*TEXTMERGE\s*DELIMITERS|SET\s*TOPIC|SET\s*TRBETWEEN|' + r'SET\s*TYPEAHEAD|SET\s*UDFPARMS|SET\s*UNIQUE|SET\s*VIEW|' + r'SET\s*VOLUME|SET\s*WINDOW\s*OF\s*MEMO|SET\s*XCMDFILE|' + r'SHOW\s*GET|SHOW\s*GETS|SHOW\s*MENU|SHOW\s*OBJECT|' + r'SHOW\s*POPUP|SHOW\s*WINDOW|SIZE\s*POPUP|SKIP|SORT|' + r'STORE|SUM|SUSPEND|TOTAL|TYPE|UNLOCK|UPDATE|USE|WAIT|' + r'ZAP|ZOOM\s*WINDOW|DO\s*CASE|CASE|OTHERWISE|ENDCASE|' + r'DO\s*WHILE|ENDDO|FOR|ENDFOR|NEXT|IF|ELSE|ENDIF|PRINTJOB|' + r'ENDPRINTJOB|SCAN|ENDSCAN|TEXT|ENDTEXT|=)', + Keyword.Reserved, '#pop'), + (r'#\s*(IF|ELIF|ELSE|ENDIF|DEFINE|IFDEF|IFNDEF|INCLUDE)', + Comment.Preproc, '#pop'), + (r'(m\.)?[a-z_]\w*', Name.Variable, '#pop'), + (r'.', Text, '#pop'), + ], + } diff --git a/pygments/lexers/freefem.py b/pygments/lexers/freefem.py new file mode 100644 index 0000000000000000000000000000000000000000..87b066b1a0dadf89194162427c61706187a33248 --- /dev/null +++ b/pygments/lexers/freefem.py @@ -0,0 +1,893 @@ +""" + pygments.lexers.freefem + ~~~~~~~~~~~~~~~~~~~~~~~ + + Lexer for FreeFem++ language. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +from pygments.token import Comment, Operator, Keyword, Name + +from pygments.lexers.c_cpp import CppLexer + +__all__ = ['FreeFemLexer'] + + +class FreeFemLexer(CppLexer): + """ + For FreeFem++ source. + + This is an extension of the CppLexer, as the FreeFem Language is a superset + of C++. + """ + + name = 'Freefem' + url = 'https://freefem.org/' + aliases = ['freefem'] + filenames = ['*.edp'] + mimetypes = ['text/x-freefem'] + version_added = '2.4' + + # Language operators + operators = {'+', '-', '*', '.*', '/', './', '%', '^', '^-1', ':', '\''} + + # types + types = {'bool', 'border', 'complex', 'dmatrix', 'fespace', 'func', 'gslspline', + 'ifstream', 'int', 'macro', 'matrix', 'mesh', 'mesh3', 'mpiComm', + 'mpiGroup', 'mpiRequest', 'NewMacro', 'EndMacro', 'ofstream', 'Pmmap', + 'problem', 'Psemaphore', 'real', 'solve', 'string', 'varf'} + + # finite element spaces + fespaces = {'BDM1', 'BDM1Ortho', 'Edge03d', 'Edge13d', 'Edge23d', 'FEQF', 'HCT', + 'P0', 'P03d', 'P0Edge', 'P1', 'P13d', 'P1b', 'P1b3d', 'P1bl', 'P1bl3d', + 'P1dc', 'P1Edge', 'P1nc', 'P2', 'P23d', 'P2b', 'P2BR', 'P2dc', 'P2Edge', + 'P2h', 'P2Morley', 'P2pnc', 'P3', 'P3dc', 'P3Edge', 'P4', 'P4dc', + 'P4Edge', 'P5Edge', 'RT0', 'RT03d', 'RT0Ortho', 'RT1', 'RT1Ortho', + 'RT2', 'RT2Ortho'} + + # preprocessor + preprocessor = {'ENDIFMACRO', 'include', 'IFMACRO', 'load'} + + # Language keywords + keywords = { + 'adj', + 'append', + 'area', + 'ARGV', + 'be', + 'binary', + 'BoundaryEdge', + 'bordermeasure', + 'CG', + 'Cholesky', + 'cin', + 'cout', + 'Crout', + 'default', + 'diag', + 'edgeOrientation', + 'endl', + 'false', + 'ffind', + 'FILE', + 'find', + 'fixed', + 'flush', + 'GMRES', + 'good', + 'hTriangle', + 'im', + 'imax', + 'imin', + 'InternalEdge', + 'l1', + 'l2', + 'label', + 'lenEdge', + 'length', + 'LINE', + 'linfty', + 'LU', + 'm', + 'max', + 'measure', + 'min', + 'mpiAnySource', + 'mpiBAND', + 'mpiBXOR', + 'mpiCommWorld', + 'mpiLAND', + 'mpiLOR', + 'mpiLXOR', + 'mpiMAX', + 'mpiMIN', + 'mpiPROD', + 'mpirank', + 'mpisize', + 'mpiSUM', + 'mpiUndefined', + 'n', + 'N', + 'nbe', + 'ndof', + 'ndofK', + 'noshowbase', + 'noshowpos', + 'notaregion', + 'nt', + 'nTonEdge', + 'nuEdge', + 'nuTriangle', + 'nv', + 'P', + 'pi', + 'precision', + 'qf1pE', + 'qf1pElump', + 'qf1pT', + 'qf1pTlump', + 'qfV1', + 'qfV1lump', + 'qf2pE', + 'qf2pT', + 'qf2pT4P1', + 'qfV2', + 'qf3pE', + 'qf4pE', + 'qf5pE', + 'qf5pT', + 'qfV5', + 'qf7pT', + 'qf9pT', + 'qfnbpE', + 'quantile', + 're', + 'region', + 'rfind', + 'scientific', + 'searchMethod', + 'setw', + 'showbase', + 'showpos', + 'sparsesolver', + 'sum', + 'tellp', + 'true', + 'UMFPACK', + 'unused', + 'whoinElement', + 'verbosity', + 'version', + 'volume', + 'x', + 'y', + 'z' + } + + # Language shipped functions and class ( ) + functions = { + 'abs', + 'acos', + 'acosh', + 'adaptmesh', + 'adj', + 'AffineCG', + 'AffineGMRES', + 'arg', + 'asin', + 'asinh', + 'assert', + 'atan', + 'atan2', + 'atanh', + 'atof', + 'atoi', + 'BFGS', + 'broadcast', + 'buildlayers', + 'buildmesh', + 'ceil', + 'chi', + 'complexEigenValue', + 'copysign', + 'change', + 'checkmovemesh', + 'clock', + 'cmaes', + 'conj', + 'convect', + 'cos', + 'cosh', + 'cube', + 'd', + 'dd', + 'dfft', + 'diffnp', + 'diffpos', + 'dimKrylov', + 'dist', + 'dumptable', + 'dx', + 'dxx', + 'dxy', + 'dxz', + 'dy', + 'dyx', + 'dyy', + 'dyz', + 'dz', + 'dzx', + 'dzy', + 'dzz', + 'EigenValue', + 'emptymesh', + 'erf', + 'erfc', + 'exec', + 'exit', + 'exp', + 'fdim', + 'floor', + 'fmax', + 'fmin', + 'fmod', + 'freeyams', + 'getARGV', + 'getline', + 'gmshload', + 'gmshload3', + 'gslcdfugaussianP', + 'gslcdfugaussianQ', + 'gslcdfugaussianPinv', + 'gslcdfugaussianQinv', + 'gslcdfgaussianP', + 'gslcdfgaussianQ', + 'gslcdfgaussianPinv', + 'gslcdfgaussianQinv', + 'gslcdfgammaP', + 'gslcdfgammaQ', + 'gslcdfgammaPinv', + 'gslcdfgammaQinv', + 'gslcdfcauchyP', + 'gslcdfcauchyQ', + 'gslcdfcauchyPinv', + 'gslcdfcauchyQinv', + 'gslcdflaplaceP', + 'gslcdflaplaceQ', + 'gslcdflaplacePinv', + 'gslcdflaplaceQinv', + 'gslcdfrayleighP', + 'gslcdfrayleighQ', + 'gslcdfrayleighPinv', + 'gslcdfrayleighQinv', + 'gslcdfchisqP', + 'gslcdfchisqQ', + 'gslcdfchisqPinv', + 'gslcdfchisqQinv', + 'gslcdfexponentialP', + 'gslcdfexponentialQ', + 'gslcdfexponentialPinv', + 'gslcdfexponentialQinv', + 'gslcdfexppowP', + 'gslcdfexppowQ', + 'gslcdftdistP', + 'gslcdftdistQ', + 'gslcdftdistPinv', + 'gslcdftdistQinv', + 'gslcdffdistP', + 'gslcdffdistQ', + 'gslcdffdistPinv', + 'gslcdffdistQinv', + 'gslcdfbetaP', + 'gslcdfbetaQ', + 'gslcdfbetaPinv', + 'gslcdfbetaQinv', + 'gslcdfflatP', + 'gslcdfflatQ', + 'gslcdfflatPinv', + 'gslcdfflatQinv', + 'gslcdflognormalP', + 'gslcdflognormalQ', + 'gslcdflognormalPinv', + 'gslcdflognormalQinv', + 'gslcdfgumbel1P', + 'gslcdfgumbel1Q', + 'gslcdfgumbel1Pinv', + 'gslcdfgumbel1Qinv', + 'gslcdfgumbel2P', + 'gslcdfgumbel2Q', + 'gslcdfgumbel2Pinv', + 'gslcdfgumbel2Qinv', + 'gslcdfweibullP', + 'gslcdfweibullQ', + 'gslcdfweibullPinv', + 'gslcdfweibullQinv', + 'gslcdfparetoP', + 'gslcdfparetoQ', + 'gslcdfparetoPinv', + 'gslcdfparetoQinv', + 'gslcdflogisticP', + 'gslcdflogisticQ', + 'gslcdflogisticPinv', + 'gslcdflogisticQinv', + 'gslcdfbinomialP', + 'gslcdfbinomialQ', + 'gslcdfpoissonP', + 'gslcdfpoissonQ', + 'gslcdfgeometricP', + 'gslcdfgeometricQ', + 'gslcdfnegativebinomialP', + 'gslcdfnegativebinomialQ', + 'gslcdfpascalP', + 'gslcdfpascalQ', + 'gslinterpakima', + 'gslinterpakimaperiodic', + 'gslinterpcsplineperiodic', + 'gslinterpcspline', + 'gslinterpsteffen', + 'gslinterplinear', + 'gslinterppolynomial', + 'gslranbernoullipdf', + 'gslranbeta', + 'gslranbetapdf', + 'gslranbinomialpdf', + 'gslranexponential', + 'gslranexponentialpdf', + 'gslranexppow', + 'gslranexppowpdf', + 'gslrancauchy', + 'gslrancauchypdf', + 'gslranchisq', + 'gslranchisqpdf', + 'gslranerlang', + 'gslranerlangpdf', + 'gslranfdist', + 'gslranfdistpdf', + 'gslranflat', + 'gslranflatpdf', + 'gslrangamma', + 'gslrangammaint', + 'gslrangammapdf', + 'gslrangammamt', + 'gslrangammaknuth', + 'gslrangaussian', + 'gslrangaussianratiomethod', + 'gslrangaussianziggurat', + 'gslrangaussianpdf', + 'gslranugaussian', + 'gslranugaussianratiomethod', + 'gslranugaussianpdf', + 'gslrangaussiantail', + 'gslrangaussiantailpdf', + 'gslranugaussiantail', + 'gslranugaussiantailpdf', + 'gslranlandau', + 'gslranlandaupdf', + 'gslrangeometricpdf', + 'gslrangumbel1', + 'gslrangumbel1pdf', + 'gslrangumbel2', + 'gslrangumbel2pdf', + 'gslranlogistic', + 'gslranlogisticpdf', + 'gslranlognormal', + 'gslranlognormalpdf', + 'gslranlogarithmicpdf', + 'gslrannegativebinomialpdf', + 'gslranpascalpdf', + 'gslranpareto', + 'gslranparetopdf', + 'gslranpoissonpdf', + 'gslranrayleigh', + 'gslranrayleighpdf', + 'gslranrayleightail', + 'gslranrayleightailpdf', + 'gslrantdist', + 'gslrantdistpdf', + 'gslranlaplace', + 'gslranlaplacepdf', + 'gslranlevy', + 'gslranweibull', + 'gslranweibullpdf', + 'gslsfairyAi', + 'gslsfairyBi', + 'gslsfairyAiscaled', + 'gslsfairyBiscaled', + 'gslsfairyAideriv', + 'gslsfairyBideriv', + 'gslsfairyAiderivscaled', + 'gslsfairyBiderivscaled', + 'gslsfairyzeroAi', + 'gslsfairyzeroBi', + 'gslsfairyzeroAideriv', + 'gslsfairyzeroBideriv', + 'gslsfbesselJ0', + 'gslsfbesselJ1', + 'gslsfbesselJn', + 'gslsfbesselY0', + 'gslsfbesselY1', + 'gslsfbesselYn', + 'gslsfbesselI0', + 'gslsfbesselI1', + 'gslsfbesselIn', + 'gslsfbesselI0scaled', + 'gslsfbesselI1scaled', + 'gslsfbesselInscaled', + 'gslsfbesselK0', + 'gslsfbesselK1', + 'gslsfbesselKn', + 'gslsfbesselK0scaled', + 'gslsfbesselK1scaled', + 'gslsfbesselKnscaled', + 'gslsfbesselj0', + 'gslsfbesselj1', + 'gslsfbesselj2', + 'gslsfbesseljl', + 'gslsfbessely0', + 'gslsfbessely1', + 'gslsfbessely2', + 'gslsfbesselyl', + 'gslsfbesseli0scaled', + 'gslsfbesseli1scaled', + 'gslsfbesseli2scaled', + 'gslsfbesselilscaled', + 'gslsfbesselk0scaled', + 'gslsfbesselk1scaled', + 'gslsfbesselk2scaled', + 'gslsfbesselklscaled', + 'gslsfbesselJnu', + 'gslsfbesselYnu', + 'gslsfbesselInuscaled', + 'gslsfbesselInu', + 'gslsfbesselKnuscaled', + 'gslsfbesselKnu', + 'gslsfbessellnKnu', + 'gslsfbesselzeroJ0', + 'gslsfbesselzeroJ1', + 'gslsfbesselzeroJnu', + 'gslsfclausen', + 'gslsfhydrogenicR1', + 'gslsfdawson', + 'gslsfdebye1', + 'gslsfdebye2', + 'gslsfdebye3', + 'gslsfdebye4', + 'gslsfdebye5', + 'gslsfdebye6', + 'gslsfdilog', + 'gslsfmultiply', + 'gslsfellintKcomp', + 'gslsfellintEcomp', + 'gslsfellintPcomp', + 'gslsfellintDcomp', + 'gslsfellintF', + 'gslsfellintE', + 'gslsfellintRC', + 'gslsferfc', + 'gslsflogerfc', + 'gslsferf', + 'gslsferfZ', + 'gslsferfQ', + 'gslsfhazard', + 'gslsfexp', + 'gslsfexpmult', + 'gslsfexpm1', + 'gslsfexprel', + 'gslsfexprel2', + 'gslsfexpreln', + 'gslsfexpintE1', + 'gslsfexpintE2', + 'gslsfexpintEn', + 'gslsfexpintE1scaled', + 'gslsfexpintE2scaled', + 'gslsfexpintEnscaled', + 'gslsfexpintEi', + 'gslsfexpintEiscaled', + 'gslsfShi', + 'gslsfChi', + 'gslsfexpint3', + 'gslsfSi', + 'gslsfCi', + 'gslsfatanint', + 'gslsffermidiracm1', + 'gslsffermidirac0', + 'gslsffermidirac1', + 'gslsffermidirac2', + 'gslsffermidiracint', + 'gslsffermidiracmhalf', + 'gslsffermidirachalf', + 'gslsffermidirac3half', + 'gslsffermidiracinc0', + 'gslsflngamma', + 'gslsfgamma', + 'gslsfgammastar', + 'gslsfgammainv', + 'gslsftaylorcoeff', + 'gslsffact', + 'gslsfdoublefact', + 'gslsflnfact', + 'gslsflndoublefact', + 'gslsflnchoose', + 'gslsfchoose', + 'gslsflnpoch', + 'gslsfpoch', + 'gslsfpochrel', + 'gslsfgammaincQ', + 'gslsfgammaincP', + 'gslsfgammainc', + 'gslsflnbeta', + 'gslsfbeta', + 'gslsfbetainc', + 'gslsfgegenpoly1', + 'gslsfgegenpoly2', + 'gslsfgegenpoly3', + 'gslsfgegenpolyn', + 'gslsfhyperg0F1', + 'gslsfhyperg1F1int', + 'gslsfhyperg1F1', + 'gslsfhypergUint', + 'gslsfhypergU', + 'gslsfhyperg2F0', + 'gslsflaguerre1', + 'gslsflaguerre2', + 'gslsflaguerre3', + 'gslsflaguerren', + 'gslsflambertW0', + 'gslsflambertWm1', + 'gslsflegendrePl', + 'gslsflegendreP1', + 'gslsflegendreP2', + 'gslsflegendreP3', + 'gslsflegendreQ0', + 'gslsflegendreQ1', + 'gslsflegendreQl', + 'gslsflegendrePlm', + 'gslsflegendresphPlm', + 'gslsflegendrearraysize', + 'gslsfconicalPhalf', + 'gslsfconicalPmhalf', + 'gslsfconicalP0', + 'gslsfconicalP1', + 'gslsfconicalPsphreg', + 'gslsfconicalPcylreg', + 'gslsflegendreH3d0', + 'gslsflegendreH3d1', + 'gslsflegendreH3d', + 'gslsflog', + 'gslsflogabs', + 'gslsflog1plusx', + 'gslsflog1plusxmx', + 'gslsfpowint', + 'gslsfpsiint', + 'gslsfpsi', + 'gslsfpsi1piy', + 'gslsfpsi1int', + 'gslsfpsi1', + 'gslsfpsin', + 'gslsfsynchrotron1', + 'gslsfsynchrotron2', + 'gslsftransport2', + 'gslsftransport3', + 'gslsftransport4', + 'gslsftransport5', + 'gslsfsin', + 'gslsfcos', + 'gslsfhypot', + 'gslsfsinc', + 'gslsflnsinh', + 'gslsflncosh', + 'gslsfanglerestrictsymm', + 'gslsfanglerestrictpos', + 'gslsfzetaint', + 'gslsfzeta', + 'gslsfzetam1', + 'gslsfzetam1int', + 'gslsfhzeta', + 'gslsfetaint', + 'gslsfeta', + 'imag', + 'int1d', + 'int2d', + 'int3d', + 'intalledges', + 'intallfaces', + 'interpolate', + 'invdiff', + 'invdiffnp', + 'invdiffpos', + 'Isend', + 'isInf', + 'isNaN', + 'isoline', + 'Irecv', + 'j0', + 'j1', + 'jn', + 'jump', + 'lgamma', + 'LinearCG', + 'LinearGMRES', + 'log', + 'log10', + 'lrint', + 'lround', + 'max', + 'mean', + 'medit', + 'min', + 'mmg3d', + 'movemesh', + 'movemesh23', + 'mpiAlltoall', + 'mpiAlltoallv', + 'mpiAllgather', + 'mpiAllgatherv', + 'mpiAllReduce', + 'mpiBarrier', + 'mpiGather', + 'mpiGatherv', + 'mpiRank', + 'mpiReduce', + 'mpiScatter', + 'mpiScatterv', + 'mpiSize', + 'mpiWait', + 'mpiWaitAny', + 'mpiWtick', + 'mpiWtime', + 'mshmet', + 'NaN', + 'NLCG', + 'on', + 'plot', + 'polar', + 'Post', + 'pow', + 'processor', + 'processorblock', + 'projection', + 'randinit', + 'randint31', + 'randint32', + 'random', + 'randreal1', + 'randreal2', + 'randreal3', + 'randres53', + 'Read', + 'readmesh', + 'readmesh3', + 'Recv', + 'rint', + 'round', + 'savemesh', + 'savesol', + 'savevtk', + 'seekg', + 'Sent', + 'set', + 'sign', + 'signbit', + 'sin', + 'sinh', + 'sort', + 'splitComm', + 'splitmesh', + 'sqrt', + 'square', + 'srandom', + 'srandomdev', + 'Stringification', + 'swap', + 'system', + 'tan', + 'tanh', + 'tellg', + 'tetg', + 'tetgconvexhull', + 'tetgreconstruction', + 'tetgtransfo', + 'tgamma', + 'triangulate', + 'trunc', + 'Wait', + 'Write', + 'y0', + 'y1', + 'yn' + } + + # function parameters + parameters = { + 'A', + 'A1', + 'abserror', + 'absolute', + 'aniso', + 'aspectratio', + 'B', + 'B1', + 'bb', + 'beginend', + 'bin', + 'boundary', + 'bw', + 'close', + 'cmm', + 'coef', + 'composante', + 'cutoff', + 'datafilename', + 'dataname', + 'dim', + 'distmax', + 'displacement', + 'doptions', + 'dparams', + 'eps', + 'err', + 'errg', + 'facemerge', + 'facetcl', + 'factorize', + 'file', + 'fill', + 'fixedborder', + 'flabel', + 'flags', + 'floatmesh', + 'floatsol', + 'fregion', + 'gradation', + 'grey', + 'hmax', + 'hmin', + 'holelist', + 'hsv', + 'init', + 'inquire', + 'inside', + 'IsMetric', + 'iso', + 'ivalue', + 'keepbackvertices', + 'label', + 'labeldown', + 'labelmid', + 'labelup', + 'levelset', + 'loptions', + 'lparams', + 'maxit', + 'maxsubdiv', + 'meditff', + 'mem', + 'memory', + 'metric', + 'mode', + 'nbarrow', + 'nbiso', + 'nbiter', + 'nbjacoby', + 'nboffacetcl', + 'nbofholes', + 'nbofregions', + 'nbregul', + 'nbsmooth', + 'nbvx', + 'ncv', + 'nev', + 'nomeshgeneration', + 'normalization', + 'omega', + 'op', + 'optimize', + 'option', + 'options', + 'order', + 'orientation', + 'periodic', + 'power', + 'precon', + 'prev', + 'ps', + 'ptmerge', + 'qfe', + 'qforder', + 'qft', + 'qfV', + 'ratio', + 'rawvector', + 'reffacelow', + 'reffacemid', + 'reffaceup', + 'refnum', + 'reftet', + 'reftri', + 'region', + 'regionlist', + 'renumv', + 'rescaling', + 'ridgeangle', + 'save', + 'sigma', + 'sizeofvolume', + 'smoothing', + 'solver', + 'sparams', + 'split', + 'splitin2', + 'splitpbedge', + 'stop', + 'strategy', + 'swap', + 'switch', + 'sym', + 't', + 'tgv', + 'thetamax', + 'tol', + 'tolpivot', + 'tolpivotsym', + 'transfo', + 'U2Vc', + 'value', + 'varrow', + 'vector', + 'veps', + 'viso', + 'wait', + 'width', + 'withsurfacemesh', + 'WindowIndex', + 'which', + 'zbound' + } + + # deprecated + deprecated = {'fixeborder'} + + # do not highlight + suppress_highlight = { + 'alignof', + 'asm', + 'constexpr', + 'decltype', + 'div', + 'double', + 'grad', + 'mutable', + 'namespace', + 'noexcept', + 'restrict', + 'static_assert', + 'template', + 'this', + 'thread_local', + 'typeid', + 'typename', + 'using' + } + + def get_tokens_unprocessed(self, text, stack=('root',)): + for index, token, value in CppLexer.get_tokens_unprocessed(self, text, stack): + if value in self.operators: + yield index, Operator, value + elif value in self.types: + yield index, Keyword.Type, value + elif value in self.fespaces: + yield index, Name.Class, value + elif value in self.preprocessor: + yield index, Comment.Preproc, value + elif value in self.keywords: + yield index, Keyword.Reserved, value + elif value in self.functions: + yield index, Name.Function, value + elif value in self.parameters: + yield index, Keyword.Pseudo, value + elif value in self.suppress_highlight: + yield index, Name, value + else: + yield index, token, value diff --git a/pygments/lexers/func.py b/pygments/lexers/func.py new file mode 100644 index 0000000000000000000000000000000000000000..d8c912a3e0899070a51e7af1565f0e18aec2a86b --- /dev/null +++ b/pygments/lexers/func.py @@ -0,0 +1,110 @@ +""" + pygments.lexers.func + ~~~~~~~~~~~~~~~~~~~~ + + Lexers for FunC. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +from pygments.lexer import RegexLexer, include, words +from pygments.token import Text, Comment, Operator, Keyword, Name, String, \ + Number, Whitespace, Punctuation + +__all__ = ['FuncLexer'] + + +class FuncLexer(RegexLexer): + """ + For FunC source code. + """ + + name = 'FunC' + aliases = ['func', 'fc'] + filenames = ['*.fc', '*.func'] + url = 'https://docs.ton.org/develop/func/overview' + version_added = '' + + # 1. Does not start from " + # 2. Can start from ` and end with `, containing any character + # 3. Starts with underscore or { or } and have more than 1 character after it + # 4. Starts with letter, contains letters, numbers and underscores + identifier = r'(?!")(`([^`]+)`|((?=_)_|(?=\{)\{|(?=\})\}|(?![_`{}]))([^;,\[\]\(\)\s~.]+))' + + tokens = { + 'root': [ + (r'\n', Whitespace), + (r'\s+', Whitespace), + + include('keywords'), + include('strings'), + include('directives'), + include('numeric'), + include('comments'), + include('storage'), + include('functions'), + include('variables'), + + (r'[.;(),\[\]~{}]', Punctuation) + ], + 'keywords': [ + (words(( + '<=>', '>=', '<=', '!=', '==', '^>>', '~>>', + '>>', '<<', '/%', '^%', '~%', '^/', '~/', '+=', + '-=', '*=', '/=', '~/=', '^/=', '%=', '^%=', '<<=', + '>>=', '~>>=', '^>>=', '&=', '|=', '^=', '^', '=', + '~', '/', '%', '-', '*', '+','>', + '<', '&', '|', ':', '?'), prefix=r'(?<=\s)', suffix=r'(?=\s)'), + Operator), + (words(( + 'if', 'ifnot', + 'else', 'elseif', 'elseifnot', + 'while', 'do', 'until', 'repeat', + 'return', 'impure', 'method_id', + 'forall', 'asm', 'inline', 'inline_ref'), prefix=r'\b', suffix=r'\b'), + Keyword), + (words(('true', 'false'), prefix=r'\b', suffix=r'\b'), Keyword.Constant), + ], + 'directives': [ + (r'#include|#pragma', Keyword, 'directive'), + ], + 'directive': [ + include('strings'), + (r'\s+', Whitespace), + (r'version|not-version', Keyword), + (r'(>=|<=|=|>|<|\^)?([0-9]+)(.[0-9]+)?(.[0-9]+)?', Number), # version + (r';', Text, '#pop') + ], + 'strings': [ + (r'\"([^\n\"]+)\"[Hhcusa]?', String), + ], + 'numeric': [ + (r'\b(-?(?!_)([\d_]+|0x[\d_a-fA-F]+)|0b[1_0]+)(?<\|&\^' + + tokens = { + 'root': [ + (r'--(.*?)$', Comment.Single), + (r'\s+', Whitespace), + (r'\(\)', Punctuation), + (r'\b({})(?!\')\b'.format('|'.join(reserved)), Keyword.Reserved), + (r'\b({})(?!\')\b'.format('|'.join(num_types + other_types)), Keyword.Type), + + # Identifiers + (r'#\[([a-zA-Z_\(\) ]*)\]', Comment.Preproc), + (rf'[#!]?({identifier_re}\.)*{identifier_re}', Name), + + (r'\\', Operator), + (r'[-+/%=!><|&*^][-+/%=!><|&*^.]*', Operator), + (r'[][(),:;`{}?.\'~^]', Punctuation), + + # Numbers + (r'0[xX]_*[\da-fA-F](_*[\da-fA-F])*_*[pP][+-]?\d(_*\d)*' + num_postfix, + Number.Float), + (r'0[xX]_*[\da-fA-F](_*[\da-fA-F])*\.[\da-fA-F](_*[\da-fA-F])*' + r'(_*[pP][+-]?\d(_*\d)*)?' + num_postfix, Number.Float), + (r'\d(_*\d)*_*[eE][+-]?\d(_*\d)*' + num_postfix, Number.Float), + (r'\d(_*\d)*\.\d(_*\d)*(_*[eE][+-]?\d(_*\d)*)?' + num_postfix, Number.Float), + (r'0[bB]_*[01](_*[01])*' + num_postfix, Number.Bin), + (r'0[xX]_*[\da-fA-F](_*[\da-fA-F])*' + num_postfix, Number.Hex), + (r'\d(_*\d)*' + num_postfix, Number.Integer), + + # Character/String Literals + (r"'", String.Char, 'character'), + (r'"', String, 'string'), + # Special + (r'\[[a-zA-Z_\d]*\]', Keyword.Type), + (r'\(\)', Name.Builtin), + ], + 'character': [ + # Allows multi-chars, incorrectly. + (r"[^\\']'", String.Char, '#pop'), + (r"\\", String.Escape, 'escape'), + ("'", String.Char, '#pop'), + ], + 'string': [ + (r'[^\\"]+', String), + (r"\\", String.Escape, 'escape'), + ('"', String, '#pop'), + ], + + 'escape': [ + (r'[abfnrtv"\'&\\]', String.Escape, '#pop'), + (r'\^[][' + uni.Lu + r'@^_]', String.Escape, '#pop'), + ('|'.join(ascii), String.Escape, '#pop'), + (r'o[0-7]+', String.Escape, '#pop'), + (r'x[\da-fA-F]+', String.Escape, '#pop'), + (r'\d+', String.Escape, '#pop'), + (r'(\s+)(\\)', bygroups(Whitespace, String.Escape), '#pop'), + ], + } diff --git a/pygments/lexers/gcodelexer.py b/pygments/lexers/gcodelexer.py new file mode 100644 index 0000000000000000000000000000000000000000..9e968bd6b67e17fe28175e002037cbf50846088c --- /dev/null +++ b/pygments/lexers/gcodelexer.py @@ -0,0 +1,35 @@ +""" + pygments.lexers.gcodelexer + ~~~~~~~~~~~~~~~~~~~~~~~~~~ + + Lexers for the G Code Language. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +from pygments.lexer import RegexLexer, bygroups +from pygments.token import Comment, Name, Text, Keyword, Number + +__all__ = ['GcodeLexer'] + + +class GcodeLexer(RegexLexer): + """ + For gcode source code. + """ + name = 'g-code' + aliases = ['gcode'] + filenames = ['*.gcode'] + url = 'https://en.wikipedia.org/wiki/G-code' + version_added = '2.9' + + tokens = { + 'root': [ + (r';.*\n', Comment), + (r'^[gmGM]\d{1,4}\s', Name.Builtin), # M or G commands + (r'([^gGmM])([+-]?\d*[.]?\d+)', bygroups(Keyword, Number)), + (r'\s', Text.Whitespace), + (r'.*\n', Text), + ] + } diff --git a/pygments/lexers/gdscript.py b/pygments/lexers/gdscript.py new file mode 100644 index 0000000000000000000000000000000000000000..6dfe589767241121098ce171cea2de604ae78d23 --- /dev/null +++ b/pygments/lexers/gdscript.py @@ -0,0 +1,189 @@ +""" + pygments.lexers.gdscript + ~~~~~~~~~~~~~~~~~~~~~~~~ + + Lexer for GDScript. + + Modified by Daniel J. Ramirez based on the original + python.py. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +import re + +from pygments.lexer import RegexLexer, include, bygroups, default, words, \ + combined +from pygments.token import Text, Comment, Operator, Keyword, Name, String, \ + Number, Punctuation, Whitespace + +__all__ = ["GDScriptLexer"] + + +class GDScriptLexer(RegexLexer): + """ + For GDScript source code. + """ + + name = "GDScript" + url = 'https://www.godotengine.org' + aliases = ["gdscript", "gd"] + filenames = ["*.gd"] + mimetypes = ["text/x-gdscript", "application/x-gdscript"] + version_added = '' + + def innerstring_rules(ttype): + return [ + # the old style '%s' % (...) string formatting + (r"%(\(\w+\))?[-#0 +]*([0-9]+|[*])?(\.([0-9]+|[*]))?" + "[hlL]?[E-GXc-giorsux%]", + String.Interpol), + # backslashes, quotes and formatting signs must be parsed one at a time + (r'[^\\\'"%\n]+', ttype), + (r'[\'"\\]', ttype), + # unhandled string formatting sign + (r"%", ttype), + # newlines are an error (use "nl" state) + ] + + tokens = { + "root": [ + (r"\n", Whitespace), + (r'^(\s*)([rRuUbB]{,2})("""(?:.|\n)*?""")', + bygroups(Whitespace, String.Affix, String.Doc)), + (r"^(\s*)([rRuUbB]{,2})('''(?:.|\n)*?''')", + bygroups(Whitespace, String.Affix, String.Doc)), + (r"[^\S\n]+", Whitespace), + (r"#.*$", Comment.Single), + (r"[]{}:(),;[]", Punctuation), + (r"(\\)(\n)", bygroups(Text, Whitespace)), + (r"\\", Text), + (r"(in|and|or|not)\b", Operator.Word), + (r"!=|==|<<|>>|&&|\+=|-=|\*=|/=|%=|&=|\|=|\|\||[-~+/*%=<>&^.!|$]", + Operator), + include("keywords"), + (r"(func)(\s+)", bygroups(Keyword, Whitespace), "funcname"), + (r"(class)(\s+)", bygroups(Keyword, Whitespace), "classname"), + include("builtins"), + ('([rR]|[uUbB][rR]|[rR][uUbB])(""")', + bygroups(String.Affix, String.Double), + "tdqs"), + ("([rR]|[uUbB][rR]|[rR][uUbB])(''')", + bygroups(String.Affix, String.Single), + "tsqs"), + ('([rR]|[uUbB][rR]|[rR][uUbB])(")', + bygroups(String.Affix, String.Double), + "dqs"), + ("([rR]|[uUbB][rR]|[rR][uUbB])(')", + bygroups(String.Affix, String.Single), + "sqs"), + ('([uUbB]?)(""")', + bygroups(String.Affix, String.Double), + combined("stringescape", "tdqs")), + ("([uUbB]?)(''')", + bygroups(String.Affix, String.Single), + combined("stringescape", "tsqs")), + ('([uUbB]?)(")', + bygroups(String.Affix, String.Double), + combined("stringescape", "dqs")), + ("([uUbB]?)(')", + bygroups(String.Affix, String.Single), + combined("stringescape", "sqs")), + include("name"), + include("numbers"), + ], + "keywords": [ + (words(("and", "in", "not", "or", "as", "breakpoint", "class", + "class_name", "extends", "is", "func", "setget", "signal", + "tool", "const", "enum", "export", "onready", "static", + "var", "break", "continue", "if", "elif", "else", "for", + "pass", "return", "match", "while", "remote", "master", + "puppet", "remotesync", "mastersync", "puppetsync"), + suffix=r"\b"), Keyword), + ], + "builtins": [ + (words(("Color8", "ColorN", "abs", "acos", "asin", "assert", "atan", + "atan2", "bytes2var", "ceil", "char", "clamp", "convert", + "cos", "cosh", "db2linear", "decimals", "dectime", "deg2rad", + "dict2inst", "ease", "exp", "floor", "fmod", "fposmod", + "funcref", "hash", "inst2dict", "instance_from_id", "is_inf", + "is_nan", "lerp", "linear2db", "load", "log", "max", "min", + "nearest_po2", "pow", "preload", "print", "print_stack", + "printerr", "printraw", "prints", "printt", "rad2deg", + "rand_range", "rand_seed", "randf", "randi", "randomize", + "range", "round", "seed", "sign", "sin", "sinh", "sqrt", + "stepify", "str", "str2var", "tan", "tan", "tanh", + "type_exist", "typeof", "var2bytes", "var2str", "weakref", + "yield"), prefix=r"(?|\|\||\||\->|<\-|&&|<<|>>|\.\.|\.|=)', Punctuation), + + # Operators + (r'(<>|\+\.?|\-\.?|\*\.?|/\.?|%\.?|<=\.?|>=\.?|<\.?|>\.?|=)', Operator), + + # Strings + (r'"(\\"|[^"])*"', String), + + # Identifiers + (r'\b(let)(\s+)(\w+)', bygroups(Keyword, Whitespace, Name.Variable)), + (r'\b(fn)(\s+)(\w+)', bygroups(Keyword, Whitespace, Name.Function)), + (r'[a-zA-Z_/]\w*', Name), + + # numbers + (r'(\d+(_\d+)*\.(?!\.)(\d+(_\d+)*)?|\.\d+(_\d+)*)([eEf][+-]?[0-9]+)?', Number.Float), + (r'\d+(_\d+)*[eEf][+-]?[0-9]+', Number.Float), + (r'0[xX][a-fA-F0-9]+(_[a-fA-F0-9]+)*(\.([a-fA-F0-9]+(_[a-fA-F0-9]+)*)?)?p[+-]?\d+', Number.Float), + (r'0[bB][01]+(_[01]+)*', Number.Bin), + (r'0[oO][0-7]+(_[0-7]+)*', Number.Oct), + (r'0[xX][a-fA-F0-9]+(_[a-fA-F0-9]+)*', Number.Hex), + (r'\d+(_\d+)*', Number.Integer), + + # Whitespace + (r'\s+', Whitespace), + + ], + } diff --git a/pygments/lexers/go.py b/pygments/lexers/go.py new file mode 100644 index 0000000000000000000000000000000000000000..77a0fbce97cd2d51a896d7e039bb7c2dbee5434a --- /dev/null +++ b/pygments/lexers/go.py @@ -0,0 +1,97 @@ +""" + pygments.lexers.go + ~~~~~~~~~~~~~~~~~~ + + Lexers for the Google Go language. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +from pygments.lexer import RegexLexer, bygroups, words +from pygments.token import Text, Comment, Operator, Keyword, Name, String, \ + Number, Punctuation, Whitespace + +__all__ = ['GoLexer'] + + +class GoLexer(RegexLexer): + """ + For Go source. + """ + name = 'Go' + url = 'https://go.dev/' + filenames = ['*.go'] + aliases = ['go', 'golang'] + mimetypes = ['text/x-gosrc'] + version_added = '1.2' + + tokens = { + 'root': [ + (r'\n', Whitespace), + (r'\s+', Whitespace), + (r'(\\)(\n)', bygroups(Text, Whitespace)), # line continuations + (r'//(.*?)$', Comment.Single), + (r'/(\\\n)?[*](.|\n)*?[*](\\\n)?/', Comment.Multiline), + (r'(import|package)\b', Keyword.Namespace), + (r'(var|func|struct|map|chan|type|interface|const)\b', + Keyword.Declaration), + (words(( + 'break', 'default', 'select', 'case', 'defer', 'go', + 'else', 'goto', 'switch', 'fallthrough', 'if', 'range', + 'continue', 'for', 'return'), suffix=r'\b'), + Keyword), + (r'(true|false|iota|nil)\b', Keyword.Constant), + # It seems the builtin types aren't actually keywords, but + # can be used as functions. So we need two declarations. + (words(( + 'uint', 'uint8', 'uint16', 'uint32', 'uint64', + 'int', 'int8', 'int16', 'int32', 'int64', + 'float', 'float32', 'float64', + 'complex64', 'complex128', 'byte', 'rune', + 'string', 'bool', 'error', 'uintptr', 'any', 'comparable', + 'print', 'println', 'panic', 'recover', 'close', 'complex', + 'real', 'imag', 'len', 'cap', 'append', 'copy', 'delete', + 'new', 'make', 'min', 'max', 'clear'), suffix=r'\b(\()'), + bygroups(Name.Builtin, Punctuation)), + (words(( + 'uint', 'uint8', 'uint16', 'uint32', 'uint64', + 'int', 'int8', 'int16', 'int32', 'int64', + 'float', 'float32', 'float64', + 'complex64', 'complex128', 'byte', 'rune', + 'string', 'bool', 'error', 'uintptr', 'any', 'comparable'), suffix=r'\b'), + Keyword.Type), + # imaginary_lit + (r'\d+i', Number), + (r'\d+\.\d*([Ee][-+]\d+)?i', Number), + (r'\.\d+([Ee][-+]\d+)?i', Number), + (r'\d+[Ee][-+]\d+i', Number), + # float_lit + (r'\d+(\.\d+[eE][+\-]?\d+|' + r'\.\d*|[eE][+\-]?\d+)', Number.Float), + (r'\.\d+([eE][+\-]?\d+)?', Number.Float), + # int_lit + # -- octal_lit + (r'0[0-7]+', Number.Oct), + # -- hex_lit + (r'0[xX][0-9a-fA-F]+', Number.Hex), + # -- decimal_lit + (r'(0|[1-9][0-9]*)', Number.Integer), + # char_lit + (r"""'(\\['"\\abfnrtv]|\\x[0-9a-fA-F]{2}|\\[0-7]{1,3}""" + r"""|\\u[0-9a-fA-F]{4}|\\U[0-9a-fA-F]{8}|[^\\])'""", + String.Char), + # StringLiteral + # -- raw_string_lit + (r'`[^`]*`', String), + # -- interpreted_string_lit + (r'"(\\\\|\\[^\\]|[^"\\])*"', String), + # Tokens + (r'(<<=|>>=|<<|>>|<=|>=|&\^=|&\^|\+=|-=|\*=|/=|%=|&=|\|=|&&|\|\|' + r'|<-|\+\+|--|==|!=|:=|\.\.\.|[+\-*/%&]' + r'|~|\|)', Operator), + (r'[|^<>=!()\[\]{}.,;:]', Punctuation), + # identifier + (r'[^\W\d]\w*', Name.Other), + ] + } diff --git a/pygments/lexers/grammar_notation.py b/pygments/lexers/grammar_notation.py new file mode 100644 index 0000000000000000000000000000000000000000..45274fd14046283374e52f6b4bc4dbaad48321b4 --- /dev/null +++ b/pygments/lexers/grammar_notation.py @@ -0,0 +1,262 @@ +""" + pygments.lexers.grammar_notation + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + Lexers for grammar notations like BNF. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +from pygments.lexer import RegexLexer, bygroups, include, this, using, words +from pygments.token import Comment, Keyword, Literal, Name, Number, \ + Operator, Punctuation, String, Text, Whitespace + +__all__ = ['BnfLexer', 'AbnfLexer', 'JsgfLexer', 'PegLexer'] + + +class BnfLexer(RegexLexer): + """ + This lexer is for grammar notations which are similar to + original BNF. + + In order to maximize a number of targets of this lexer, + let's decide some designs: + + * We don't distinguish `Terminal Symbol`. + + * We do assume that `NonTerminal Symbol` are always enclosed + with arrow brackets. + + * We do assume that `NonTerminal Symbol` may include + any printable characters except arrow brackets and ASCII 0x20. + This assumption is for `RBNF `_. + + * We do assume that target notation doesn't support comment. + + * We don't distinguish any operators and punctuation except + `::=`. + + Though these decision making might cause too minimal highlighting + and you might be disappointed, but it is reasonable for us. + """ + + name = 'BNF' + aliases = ['bnf'] + filenames = ['*.bnf'] + mimetypes = ['text/x-bnf'] + url = 'https://en.wikipedia.org/wiki/Backus%E2%80%93Naur_form' + version_added = '2.1' + + tokens = { + 'root': [ + (r'(<)([ -;=?-~]+)(>)', + bygroups(Punctuation, Name.Class, Punctuation)), + + # an only operator + (r'::=', Operator), + + # fallback + (r'[^<>:]+', Text), # for performance + (r'.', Text), + ], + } + + +class AbnfLexer(RegexLexer): + """ + Lexer for IETF 7405 ABNF. + + (Updates `5234 `_) grammars. + """ + + name = 'ABNF' + url = 'http://www.ietf.org/rfc/rfc7405.txt' + aliases = ['abnf'] + filenames = ['*.abnf'] + mimetypes = ['text/x-abnf'] + version_added = '2.1' + + _core_rules = ( + 'ALPHA', 'BIT', 'CHAR', 'CR', 'CRLF', 'CTL', 'DIGIT', + 'DQUOTE', 'HEXDIG', 'HTAB', 'LF', 'LWSP', 'OCTET', + 'SP', 'VCHAR', 'WSP') + + tokens = { + 'root': [ + # comment + (r';.*$', Comment.Single), + + # quoted + # double quote itself in this state, it is as '%x22'. + (r'(%[si])?"[^"]*"', Literal), + + # binary (but i have never seen...) + (r'%b[01]+\-[01]+\b', Literal), # range + (r'%b[01]+(\.[01]+)*\b', Literal), # concat + + # decimal + (r'%d[0-9]+\-[0-9]+\b', Literal), # range + (r'%d[0-9]+(\.[0-9]+)*\b', Literal), # concat + + # hexadecimal + (r'%x[0-9a-fA-F]+\-[0-9a-fA-F]+\b', Literal), # range + (r'%x[0-9a-fA-F]+(\.[0-9a-fA-F]+)*\b', Literal), # concat + + # repetition (*element) including nRule + (r'\b[0-9]+\*[0-9]+', Operator), + (r'\b[0-9]+\*', Operator), + (r'\b[0-9]+', Operator), + (r'\*', Operator), + + # Strictly speaking, these are not keyword but + # are called `Core Rule'. + (words(_core_rules, suffix=r'\b'), Keyword), + + # nonterminals (ALPHA *(ALPHA / DIGIT / "-")) + (r'[a-zA-Z][a-zA-Z0-9-]*\b', Name.Class), + + # operators + (r'(=/|=|/)', Operator), + + # punctuation + (r'[\[\]()]', Punctuation), + + # fallback + (r'\s+', Whitespace), + (r'.', Text), + ], + } + + +class JsgfLexer(RegexLexer): + """ + For JSpeech Grammar Format grammars. + """ + name = 'JSGF' + url = 'https://www.w3.org/TR/jsgf/' + aliases = ['jsgf'] + filenames = ['*.jsgf'] + mimetypes = ['application/jsgf', 'application/x-jsgf', 'text/jsgf'] + version_added = '2.2' + + tokens = { + 'root': [ + include('comments'), + include('non-comments'), + ], + 'comments': [ + (r'/\*\*(?!/)', Comment.Multiline, 'documentation comment'), + (r'/\*[\w\W]*?\*/', Comment.Multiline), + (r'//.*$', Comment.Single), + ], + 'non-comments': [ + (r'\A#JSGF[^;]*', Comment.Preproc), + (r'\s+', Whitespace), + (r';', Punctuation), + (r'[=|()\[\]*+]', Operator), + (r'/[^/]+/', Number.Float), + (r'"', String.Double, 'string'), + (r'\{', String.Other, 'tag'), + (words(('import', 'public'), suffix=r'\b'), Keyword.Reserved), + (r'grammar\b', Keyword.Reserved, 'grammar name'), + (r'(<)(NULL|VOID)(>)', + bygroups(Punctuation, Name.Builtin, Punctuation)), + (r'<', Punctuation, 'rulename'), + (r'\w+|[^\s;=|()\[\]*+/"{<\w]+', Text), + ], + 'string': [ + (r'"', String.Double, '#pop'), + (r'\\.', String.Escape), + (r'[^\\"]+', String.Double), + ], + 'tag': [ + (r'\}', String.Other, '#pop'), + (r'\\.', String.Escape), + (r'[^\\}]+', String.Other), + ], + 'grammar name': [ + (r';', Punctuation, '#pop'), + (r'\s+', Whitespace), + (r'\.', Punctuation), + (r'[^;\s.]+', Name.Namespace), + ], + 'rulename': [ + (r'>', Punctuation, '#pop'), + (r'\*', Punctuation), + (r'\s+', Whitespace), + (r'([^.>]+)(\s*)(\.)', bygroups(Name.Namespace, Text, Punctuation)), + (r'[^.>]+', Name.Constant), + ], + 'documentation comment': [ + (r'\*/', Comment.Multiline, '#pop'), + (r'^(\s*)(\*?)(\s*)(@(?:example|see))(\s+)' + r'([\w\W]*?(?=(?:^\s*\*?\s*@|\*/)))', + bygroups(Whitespace, Comment.Multiline, Whitespace, Comment.Special, + Whitespace, using(this, state='example'))), + (r'(^\s*\*?\s*)(@\S*)', + bygroups(Comment.Multiline, Comment.Special)), + (r'[^*\n@]+|\w|\W', Comment.Multiline), + ], + 'example': [ + (r'(\n\s*)(\*)', bygroups(Whitespace, Comment.Multiline)), + include('non-comments'), + (r'.', Comment.Multiline), + ], + } + + +class PegLexer(RegexLexer): + """ + This lexer is for Parsing Expression Grammars (PEG). + + Various implementations of PEG have made different decisions + regarding the syntax, so let's try to be accommodating: + + * `<-`, `←`, `:`, and `=` are all accepted as rule operators. + + * Both `|` and `/` are choice operators. + + * `^`, `↑`, and `~` are cut operators. + + * A single `a-z` character immediately before a string, or + multiple `a-z` characters following a string, are part of the + string (e.g., `r"..."` or `"..."ilmsuxa`). + """ + + name = 'PEG' + url = 'https://bford.info/pub/lang/peg.pdf' + aliases = ['peg'] + filenames = ['*.peg'] + mimetypes = ['text/x-peg'] + version_added = '2.6' + + tokens = { + 'root': [ + # Comments + (r'#.*$', Comment.Single), + + # All operators + (r'<-|[←:=/|&!?*+^↑~]', Operator), + + # Other punctuation + (r'[()]', Punctuation), + + # Keywords + (r'\.', Keyword), + + # Character classes + (r'(\[)([^\]]*(?:\\.[^\]\\]*)*)(\])', + bygroups(Punctuation, String, Punctuation)), + + # Single and double quoted strings (with optional modifiers) + (r'[a-z]?"[^"\\]*(?:\\.[^"\\]*)*"[a-z]*', String.Double), + (r"[a-z]?'[^'\\]*(?:\\.[^'\\]*)*'[a-z]*", String.Single), + + # Nonterminals are not whitespace, operators, or punctuation + (r'[^\s<←:=/|&!?*+\^↑~()\[\]"\'#]+', Name.Class), + + # Fallback + (r'.', Text), + ], + } diff --git a/pygments/lexers/graph.py b/pygments/lexers/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..01d75743abee16bac4c6ad8314e9fb27f63c6158 --- /dev/null +++ b/pygments/lexers/graph.py @@ -0,0 +1,108 @@ +""" + pygments.lexers.graph + ~~~~~~~~~~~~~~~~~~~~~ + + Lexers for graph query languages. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +import re + +from pygments.lexer import RegexLexer, include, bygroups, using, this, words +from pygments.token import Keyword, Punctuation, Comment, Operator, Name,\ + String, Number, Whitespace + + +__all__ = ['CypherLexer'] + + +class CypherLexer(RegexLexer): + """ + For Cypher Query Language + + For the Cypher version in Neo4j 3.3 + """ + name = 'Cypher' + url = 'https://neo4j.com/docs/developer-manual/3.3/cypher/' + aliases = ['cypher'] + filenames = ['*.cyp', '*.cypher'] + version_added = '2.0' + + flags = re.MULTILINE | re.IGNORECASE + + tokens = { + 'root': [ + include('clauses'), + include('keywords'), + include('relations'), + include('strings'), + include('whitespace'), + include('barewords'), + include('comment'), + ], + 'keywords': [ + (r'(create|order|match|limit|set|skip|start|return|with|where|' + r'delete|foreach|not|by|true|false)\b', Keyword), + ], + 'clauses': [ + # based on https://neo4j.com/docs/cypher-refcard/3.3/ + (r'(create)(\s+)(index|unique)\b', + bygroups(Keyword, Whitespace, Keyword)), + (r'(drop)(\s+)(contraint|index)(\s+)(on)\b', + bygroups(Keyword, Whitespace, Keyword, Whitespace, Keyword)), + (r'(ends)(\s+)(with)\b', + bygroups(Keyword, Whitespace, Keyword)), + (r'(is)(\s+)(node)(\s+)(key)\b', + bygroups(Keyword, Whitespace, Keyword, Whitespace, Keyword)), + (r'(is)(\s+)(null|unique)\b', + bygroups(Keyword, Whitespace, Keyword)), + (r'(load)(\s+)(csv)(\s+)(from)\b', + bygroups(Keyword, Whitespace, Keyword, Whitespace, Keyword)), + (r'(on)(\s+)(match|create)\b', + bygroups(Keyword, Whitespace, Keyword)), + (r'(optional)(\s+)(match)\b', + bygroups(Keyword, Whitespace, Keyword)), + (r'(order)(\s+)(by)\b', + bygroups(Keyword, Whitespace, Keyword)), + (r'(starts)(\s+)(with)\b', + bygroups(Keyword, Whitespace, Keyword)), + (r'(union)(\s+)(all)\b', + bygroups(Keyword, Whitespace, Keyword)), + (r'(using)(\s+)(periodic)(\s+)(commit)\b', + bygroups(Keyword, Whitespace, Keyword, Whitespace, Keyword)), + (r'(using)(\s+)(index)\b', + bygroups(Keyword, Whitespace, Keyword)), + (r'(using)(\s+)(range|text|point)(\s+)(index)\b', + bygroups(Keyword, Whitespace, Name, Whitespace, Keyword)), + (words(( + 'all', 'any', 'as', 'asc', 'ascending', 'assert', 'call', 'case', 'create', + 'delete', 'desc', 'descending', 'distinct', 'end', 'fieldterminator', + 'foreach', 'in', 'limit', 'match', 'merge', 'none', 'not', 'null', + 'remove', 'return', 'set', 'skip', 'single', 'start', 'then', 'union', + 'unwind', 'yield', 'where', 'when', 'with', 'collect'), suffix=r'\b'), Keyword), + ], + 'relations': [ + (r'(-\[)(.*?)(\]->)', bygroups(Operator, using(this), Operator)), + (r'(<-\[)(.*?)(\]-)', bygroups(Operator, using(this), Operator)), + (r'(-\[)(.*?)(\]-)', bygroups(Operator, using(this), Operator)), + (r'-->|<--|\[|\]', Operator), + (r'<|>|<>|=|<=|=>|\(|\)|\||:|,|;', Punctuation), + (r'[.*{}]', Punctuation), + ], + 'strings': [ + (r'([\'"])(?:\\[tbnrf\'"\\]|[^\\])*?\1', String), + (r'`(?:``|[^`])+`', Name.Variable), + ], + 'whitespace': [ + (r'\s+', Whitespace), + ], + 'barewords': [ + (r'[a-z]\w*', Name), + (r'\d+', Number), + ], + 'comment': [ + (r'//.*$', Comment.Single), + ], + } diff --git a/pygments/lexers/graphics.py b/pygments/lexers/graphics.py new file mode 100644 index 0000000000000000000000000000000000000000..400be4fd8950b130454b9b43c039415fdc166309 --- /dev/null +++ b/pygments/lexers/graphics.py @@ -0,0 +1,794 @@ +""" + pygments.lexers.graphics + ~~~~~~~~~~~~~~~~~~~~~~~~ + + Lexers for computer graphics and plotting related languages. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +from pygments.lexer import RegexLexer, words, include, bygroups, using, \ + this, default +from pygments.token import Text, Comment, Operator, Keyword, Name, \ + Number, Punctuation, String, Whitespace + +__all__ = ['GLShaderLexer', 'PostScriptLexer', 'AsymptoteLexer', 'GnuplotLexer', + 'PovrayLexer', 'HLSLShaderLexer'] + + +class GLShaderLexer(RegexLexer): + """ + GLSL (OpenGL Shader) lexer. + """ + name = 'GLSL' + aliases = ['glsl'] + filenames = ['*.vert', '*.frag', '*.geo'] + mimetypes = ['text/x-glslsrc'] + url = 'https://www.khronos.org/api/opengl' + version_added = '1.1' + + tokens = { + 'root': [ + (r'#(?:.*\\\n)*.*$', Comment.Preproc), + (r'//.*$', Comment.Single), + (r'/(\\\n)?[*](.|\n)*?[*](\\\n)?/', Comment.Multiline), + (r'\+|-|~|!=?|\*|/|%|<<|>>|<=?|>=?|==?|&&?|\^|\|\|?', + Operator), + (r'[?:]', Operator), # quick hack for ternary + (r'\bdefined\b', Operator), + (r'[;{}(),\[\]]', Punctuation), + # FIXME when e is present, no decimal point needed + (r'[+-]?\d*\.\d+([eE][-+]?\d+)?', Number.Float), + (r'[+-]?\d+\.\d*([eE][-+]?\d+)?', Number.Float), + (r'0[xX][0-9a-fA-F]*', Number.Hex), + (r'0[0-7]*', Number.Oct), + (r'[1-9][0-9]*', Number.Integer), + (words(( + # Storage qualifiers + 'attribute', 'const', 'uniform', 'varying', + 'buffer', 'shared', 'in', 'out', + # Layout qualifiers + 'layout', + # Interpolation qualifiers + 'flat', 'smooth', 'noperspective', + # Auxiliary qualifiers + 'centroid', 'sample', 'patch', + # Parameter qualifiers. Some double as Storage qualifiers + 'inout', + # Precision qualifiers + 'lowp', 'mediump', 'highp', 'precision', + # Invariance qualifiers + 'invariant', + # Precise qualifiers + 'precise', + # Memory qualifiers + 'coherent', 'volatile', 'restrict', 'readonly', 'writeonly', + # Statements + 'break', 'continue', 'do', 'for', 'while', 'switch', + 'case', 'default', 'if', 'else', 'subroutine', + 'discard', 'return', 'struct'), + prefix=r'\b', suffix=r'\b'), + Keyword), + (words(( + # Boolean values + 'true', 'false'), + prefix=r'\b', suffix=r'\b'), + Keyword.Constant), + (words(( + # Miscellaneous types + 'void', 'atomic_uint', + # Floating-point scalars and vectors + 'float', 'vec2', 'vec3', 'vec4', + 'double', 'dvec2', 'dvec3', 'dvec4', + # Integer scalars and vectors + 'int', 'ivec2', 'ivec3', 'ivec4', + 'uint', 'uvec2', 'uvec3', 'uvec4', + # Boolean scalars and vectors + 'bool', 'bvec2', 'bvec3', 'bvec4', + # Matrices + 'mat2', 'mat3', 'mat4', 'dmat2', 'dmat3', 'dmat4', + 'mat2x2', 'mat2x3', 'mat2x4', 'dmat2x2', 'dmat2x3', 'dmat2x4', + 'mat3x2', 'mat3x3', 'mat3x4', 'dmat3x2', 'dmat3x3', + 'dmat3x4', 'mat4x2', 'mat4x3', 'mat4x4', 'dmat4x2', 'dmat4x3', 'dmat4x4', + # Floating-point samplers + 'sampler1D', 'sampler2D', 'sampler3D', 'samplerCube', + 'sampler1DArray', 'sampler2DArray', 'samplerCubeArray', + 'sampler2DRect', 'samplerBuffer', + 'sampler2DMS', 'sampler2DMSArray', + # Shadow samplers + 'sampler1DShadow', 'sampler2DShadow', 'samplerCubeShadow', + 'sampler1DArrayShadow', 'sampler2DArrayShadow', + 'samplerCubeArrayShadow', 'sampler2DRectShadow', + # Signed integer samplers + 'isampler1D', 'isampler2D', 'isampler3D', 'isamplerCube', + 'isampler1DArray', 'isampler2DArray', 'isamplerCubeArray', + 'isampler2DRect', 'isamplerBuffer', + 'isampler2DMS', 'isampler2DMSArray', + # Unsigned integer samplers + 'usampler1D', 'usampler2D', 'usampler3D', 'usamplerCube', + 'usampler1DArray', 'usampler2DArray', 'usamplerCubeArray', + 'usampler2DRect', 'usamplerBuffer', + 'usampler2DMS', 'usampler2DMSArray', + # Floating-point image types + 'image1D', 'image2D', 'image3D', 'imageCube', + 'image1DArray', 'image2DArray', 'imageCubeArray', + 'image2DRect', 'imageBuffer', + 'image2DMS', 'image2DMSArray', + # Signed integer image types + 'iimage1D', 'iimage2D', 'iimage3D', 'iimageCube', + 'iimage1DArray', 'iimage2DArray', 'iimageCubeArray', + 'iimage2DRect', 'iimageBuffer', + 'iimage2DMS', 'iimage2DMSArray', + # Unsigned integer image types + 'uimage1D', 'uimage2D', 'uimage3D', 'uimageCube', + 'uimage1DArray', 'uimage2DArray', 'uimageCubeArray', + 'uimage2DRect', 'uimageBuffer', + 'uimage2DMS', 'uimage2DMSArray'), + prefix=r'\b', suffix=r'\b'), + Keyword.Type), + (words(( + # Reserved for future use. + 'common', 'partition', 'active', 'asm', 'class', + 'union', 'enum', 'typedef', 'template', 'this', + 'resource', 'goto', 'inline', 'noinline', 'public', + 'static', 'extern', 'external', 'interface', 'long', + 'short', 'half', 'fixed', 'unsigned', 'superp', 'input', + 'output', 'hvec2', 'hvec3', 'hvec4', 'fvec2', 'fvec3', + 'fvec4', 'sampler3DRect', 'filter', 'sizeof', 'cast', + 'namespace', 'using'), + prefix=r'\b', suffix=r'\b'), + Keyword.Reserved), + # All names beginning with "gl_" are reserved. + (r'gl_\w*', Name.Builtin), + (r'[a-zA-Z_]\w*', Name), + (r'\.', Punctuation), + (r'\s+', Whitespace), + ], + } + + +class HLSLShaderLexer(RegexLexer): + """ + HLSL (Microsoft Direct3D Shader) lexer. + """ + name = 'HLSL' + aliases = ['hlsl'] + filenames = ['*.hlsl', '*.hlsli'] + mimetypes = ['text/x-hlsl'] + url = 'https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl' + version_added = '2.3' + + tokens = { + 'root': [ + (r'#(?:.*\\\n)*.*$', Comment.Preproc), + (r'//.*$', Comment.Single), + (r'/(\\\n)?[*](.|\n)*?[*](\\\n)?/', Comment.Multiline), + (r'\+|-|~|!=?|\*|/|%|<<|>>|<=?|>=?|==?|&&?|\^|\|\|?', + Operator), + (r'[?:]', Operator), # quick hack for ternary + (r'\bdefined\b', Operator), + (r'[;{}(),.\[\]]', Punctuation), + # FIXME when e is present, no decimal point needed + (r'[+-]?\d*\.\d+([eE][-+]?\d+)?f?', Number.Float), + (r'[+-]?\d+\.\d*([eE][-+]?\d+)?f?', Number.Float), + (r'0[xX][0-9a-fA-F]*', Number.Hex), + (r'0[0-7]*', Number.Oct), + (r'[1-9][0-9]*', Number.Integer), + (r'"', String, 'string'), + (words(( + 'asm','asm_fragment','break','case','cbuffer','centroid','class', + 'column_major','compile','compile_fragment','const','continue', + 'default','discard','do','else','export','extern','for','fxgroup', + 'globallycoherent','groupshared','if','in','inline','inout', + 'interface','line','lineadj','linear','namespace','nointerpolation', + 'noperspective','NULL','out','packoffset','pass','pixelfragment', + 'point','precise','return','register','row_major','sample', + 'sampler','shared','stateblock','stateblock_state','static', + 'struct','switch','tbuffer','technique','technique10', + 'technique11','texture','typedef','triangle','triangleadj', + 'uniform','vertexfragment','volatile','while'), + prefix=r'\b', suffix=r'\b'), + Keyword), + (words(('true','false'), prefix=r'\b', suffix=r'\b'), + Keyword.Constant), + (words(( + 'auto','catch','char','const_cast','delete','dynamic_cast','enum', + 'explicit','friend','goto','long','mutable','new','operator', + 'private','protected','public','reinterpret_cast','short','signed', + 'sizeof','static_cast','template','this','throw','try','typename', + 'union','unsigned','using','virtual'), + prefix=r'\b', suffix=r'\b'), + Keyword.Reserved), + (words(( + 'dword','matrix','snorm','string','unorm','unsigned','void','vector', + 'BlendState','Buffer','ByteAddressBuffer','ComputeShader', + 'DepthStencilState','DepthStencilView','DomainShader', + 'GeometryShader','HullShader','InputPatch','LineStream', + 'OutputPatch','PixelShader','PointStream','RasterizerState', + 'RenderTargetView','RasterizerOrderedBuffer', + 'RasterizerOrderedByteAddressBuffer', + 'RasterizerOrderedStructuredBuffer','RasterizerOrderedTexture1D', + 'RasterizerOrderedTexture1DArray','RasterizerOrderedTexture2D', + 'RasterizerOrderedTexture2DArray','RasterizerOrderedTexture3D', + 'RWBuffer','RWByteAddressBuffer','RWStructuredBuffer', + 'RWTexture1D','RWTexture1DArray','RWTexture2D','RWTexture2DArray', + 'RWTexture3D','SamplerState','SamplerComparisonState', + 'StructuredBuffer','Texture1D','Texture1DArray','Texture2D', + 'Texture2DArray','Texture2DMS','Texture2DMSArray','Texture3D', + 'TextureCube','TextureCubeArray','TriangleStream','VertexShader'), + prefix=r'\b', suffix=r'\b'), + Keyword.Type), + (words(( + 'bool','double','float','int','half','min16float','min10float', + 'min16int','min12int','min16uint','uint'), + prefix=r'\b', suffix=r'([1-4](x[1-4])?)?\b'), + Keyword.Type), # vector and matrix types + (words(( + 'abort','abs','acos','all','AllMemoryBarrier', + 'AllMemoryBarrierWithGroupSync','any','AppendStructuredBuffer', + 'asdouble','asfloat','asin','asint','asuint','asuint','atan', + 'atan2','ceil','CheckAccessFullyMapped','clamp','clip', + 'CompileShader','ConsumeStructuredBuffer','cos','cosh','countbits', + 'cross','D3DCOLORtoUBYTE4','ddx','ddx_coarse','ddx_fine','ddy', + 'ddy_coarse','ddy_fine','degrees','determinant', + 'DeviceMemoryBarrier','DeviceMemoryBarrierWithGroupSync','distance', + 'dot','dst','errorf','EvaluateAttributeAtCentroid', + 'EvaluateAttributeAtSample','EvaluateAttributeSnapped','exp', + 'exp2','f16tof32','f32tof16','faceforward','firstbithigh', + 'firstbitlow','floor','fma','fmod','frac','frexp','fwidth', + 'GetRenderTargetSampleCount','GetRenderTargetSamplePosition', + 'GlobalOrderedCountIncrement','GroupMemoryBarrier', + 'GroupMemoryBarrierWithGroupSync','InterlockedAdd','InterlockedAnd', + 'InterlockedCompareExchange','InterlockedCompareStore', + 'InterlockedExchange','InterlockedMax','InterlockedMin', + 'InterlockedOr','InterlockedXor','isfinite','isinf','isnan', + 'ldexp','length','lerp','lit','log','log10','log2','mad','max', + 'min','modf','msad4','mul','noise','normalize','pow','printf', + 'Process2DQuadTessFactorsAvg','Process2DQuadTessFactorsMax', + 'Process2DQuadTessFactorsMin','ProcessIsolineTessFactors', + 'ProcessQuadTessFactorsAvg','ProcessQuadTessFactorsMax', + 'ProcessQuadTessFactorsMin','ProcessTriTessFactorsAvg', + 'ProcessTriTessFactorsMax','ProcessTriTessFactorsMin', + 'QuadReadLaneAt','QuadSwapX','QuadSwapY','radians','rcp', + 'reflect','refract','reversebits','round','rsqrt','saturate', + 'sign','sin','sincos','sinh','smoothstep','sqrt','step','tan', + 'tanh','tex1D','tex1D','tex1Dbias','tex1Dgrad','tex1Dlod', + 'tex1Dproj','tex2D','tex2D','tex2Dbias','tex2Dgrad','tex2Dlod', + 'tex2Dproj','tex3D','tex3D','tex3Dbias','tex3Dgrad','tex3Dlod', + 'tex3Dproj','texCUBE','texCUBE','texCUBEbias','texCUBEgrad', + 'texCUBElod','texCUBEproj','transpose','trunc','WaveAllBitAnd', + 'WaveAllMax','WaveAllMin','WaveAllBitOr','WaveAllBitXor', + 'WaveAllEqual','WaveAllProduct','WaveAllSum','WaveAllTrue', + 'WaveAnyTrue','WaveBallot','WaveGetLaneCount','WaveGetLaneIndex', + 'WaveGetOrderedIndex','WaveIsHelperLane','WaveOnce', + 'WavePrefixProduct','WavePrefixSum','WaveReadFirstLane', + 'WaveReadLaneAt'), + prefix=r'\b', suffix=r'\b'), + Name.Builtin), # built-in functions + (words(( + 'SV_ClipDistance','SV_ClipDistance0','SV_ClipDistance1', + 'SV_Culldistance','SV_CullDistance0','SV_CullDistance1', + 'SV_Coverage','SV_Depth','SV_DepthGreaterEqual', + 'SV_DepthLessEqual','SV_DispatchThreadID','SV_DomainLocation', + 'SV_GroupID','SV_GroupIndex','SV_GroupThreadID','SV_GSInstanceID', + 'SV_InnerCoverage','SV_InsideTessFactor','SV_InstanceID', + 'SV_IsFrontFace','SV_OutputControlPointID','SV_Position', + 'SV_PrimitiveID','SV_RenderTargetArrayIndex','SV_SampleIndex', + 'SV_StencilRef','SV_TessFactor','SV_VertexID', + 'SV_ViewportArrayIndex'), + prefix=r'\b', suffix=r'\b'), + Name.Decorator), # system-value semantics + (r'\bSV_Target[0-7]?\b', Name.Decorator), + (words(( + 'allow_uav_condition','branch','call','domain','earlydepthstencil', + 'fastopt','flatten','forcecase','instance','loop','maxtessfactor', + 'numthreads','outputcontrolpoints','outputtopology','partitioning', + 'patchconstantfunc','unroll'), + prefix=r'\b', suffix=r'\b'), + Name.Decorator), # attributes + (r'[a-zA-Z_]\w*', Name), + (r'\\$', Comment.Preproc), # backslash at end of line -- usually macro continuation + (r'\s+', Whitespace), + ], + 'string': [ + (r'"', String, '#pop'), + (r'\\([\\abfnrtv"\']|x[a-fA-F0-9]{2,4}|' + r'u[a-fA-F0-9]{4}|U[a-fA-F0-9]{8}|[0-7]{1,3})', String.Escape), + (r'[^\\"\n]+', String), # all other characters + (r'\\\n', String), # line continuation + (r'\\', String), # stray backslash + ], + } + + +class PostScriptLexer(RegexLexer): + """ + Lexer for PostScript files. + """ + name = 'PostScript' + url = 'https://en.wikipedia.org/wiki/PostScript' + aliases = ['postscript', 'postscr'] + filenames = ['*.ps', '*.eps'] + mimetypes = ['application/postscript'] + version_added = '1.4' + + delimiter = r'()<>\[\]{}/%\s' + delimiter_end = rf'(?=[{delimiter}])' + + valid_name_chars = rf'[^{delimiter}]' + valid_name = rf"{valid_name_chars}+{delimiter_end}" + + tokens = { + 'root': [ + # All comment types + (r'^%!.+$', Comment.Preproc), + (r'%%.*$', Comment.Special), + (r'(^%.*\n){2,}', Comment.Multiline), + (r'%.*$', Comment.Single), + + # String literals are awkward; enter separate state. + (r'\(', String, 'stringliteral'), + + (r'[{}<>\[\]]', Punctuation), + + # Numbers + (r'<[0-9A-Fa-f]+>' + delimiter_end, Number.Hex), + # Slight abuse: use Oct to signify any explicit base system + (r'[0-9]+\#(\-|\+)?([0-9]+\.?|[0-9]*\.[0-9]+|[0-9]+\.[0-9]*)' + r'((e|E)[0-9]+)?' + delimiter_end, Number.Oct), + (r'(\-|\+)?([0-9]+\.?|[0-9]*\.[0-9]+|[0-9]+\.[0-9]*)((e|E)[0-9]+)?' + + delimiter_end, Number.Float), + (r'(\-|\+)?[0-9]+' + delimiter_end, Number.Integer), + + # References + (rf'\/{valid_name}', Name.Variable), + + # Names + (valid_name, Name.Function), # Anything else is executed + + # These keywords taken from + # + # Is there an authoritative list anywhere that doesn't involve + # trawling documentation? + + (r'(false|true)' + delimiter_end, Keyword.Constant), + + # Conditionals / flow control + (r'(eq|ne|g[et]|l[et]|and|or|not|if(?:else)?|for(?:all)?)' + + delimiter_end, Keyword.Reserved), + + (words(( + 'abs', 'add', 'aload', 'arc', 'arcn', 'array', 'atan', 'begin', + 'bind', 'ceiling', 'charpath', 'clip', 'closepath', 'concat', + 'concatmatrix', 'copy', 'cos', 'currentlinewidth', 'currentmatrix', + 'currentpoint', 'curveto', 'cvi', 'cvs', 'def', 'defaultmatrix', + 'dict', 'dictstackoverflow', 'div', 'dtransform', 'dup', 'end', + 'exch', 'exec', 'exit', 'exp', 'fill', 'findfont', 'floor', 'get', + 'getinterval', 'grestore', 'gsave', 'gt', 'identmatrix', 'idiv', + 'idtransform', 'index', 'invertmatrix', 'itransform', 'length', + 'lineto', 'ln', 'load', 'log', 'loop', 'matrix', 'mod', 'moveto', + 'mul', 'neg', 'newpath', 'pathforall', 'pathbbox', 'pop', 'print', + 'pstack', 'put', 'quit', 'rand', 'rangecheck', 'rcurveto', 'repeat', + 'restore', 'rlineto', 'rmoveto', 'roll', 'rotate', 'round', 'run', + 'save', 'scale', 'scalefont', 'setdash', 'setfont', 'setgray', + 'setlinecap', 'setlinejoin', 'setlinewidth', 'setmatrix', + 'setrgbcolor', 'shfill', 'show', 'showpage', 'sin', 'sqrt', + 'stack', 'stringwidth', 'stroke', 'strokepath', 'sub', 'syntaxerror', + 'transform', 'translate', 'truncate', 'typecheck', 'undefined', + 'undefinedfilename', 'undefinedresult'), suffix=delimiter_end), + Name.Builtin), + + (r'\s+', Whitespace), + ], + + 'stringliteral': [ + (r'[^()\\]+', String), + (r'\\', String.Escape, 'escape'), + (r'\(', String, '#push'), + (r'\)', String, '#pop'), + ], + + 'escape': [ + (r'[0-8]{3}|n|r|t|b|f|\\|\(|\)', String.Escape, '#pop'), + default('#pop'), + ], + } + + +class AsymptoteLexer(RegexLexer): + """ + For Asymptote source code. + """ + name = 'Asymptote' + url = 'http://asymptote.sf.net/' + aliases = ['asymptote', 'asy'] + filenames = ['*.asy'] + mimetypes = ['text/x-asymptote'] + version_added = '1.2' + + #: optional Comment or Whitespace + _ws = r'(?:\s|//.*?\n|/\*.*?\*/)+' + + tokens = { + 'whitespace': [ + (r'\n', Whitespace), + (r'\s+', Whitespace), + (r'(\\)(\n)', bygroups(Text, Whitespace)), # line continuation + (r'//(\n|(.|\n)*?[^\\]\n)', Comment), + (r'/(\\\n)?\*(.|\n)*?\*(\\\n)?/', Comment), + ], + 'statements': [ + # simple string (TeX friendly) + (r'"(\\\\|\\[^\\]|[^"\\])*"', String), + # C style string (with character escapes) + (r"'", String, 'string'), + (r'(\d+\.\d*|\.\d+|\d+)[eE][+-]?\d+[lL]?', Number.Float), + (r'(\d+\.\d*|\.\d+|\d+[fF])[fF]?', Number.Float), + (r'0x[0-9a-fA-F]+[Ll]?', Number.Hex), + (r'0[0-7]+[Ll]?', Number.Oct), + (r'\d+[Ll]?', Number.Integer), + (r'[~!%^&*+=|?:<>/-]', Operator), + (r'[()\[\],.]', Punctuation), + (r'\b(case)(.+?)(:)', bygroups(Keyword, using(this), Text)), + (r'(and|controls|tension|atleast|curl|if|else|while|for|do|' + r'return|break|continue|struct|typedef|new|access|import|' + r'unravel|from|include|quote|static|public|private|restricted|' + r'this|explicit|true|false|null|cycle|newframe|operator)\b', Keyword), + # Since an asy-type-name can be also an asy-function-name, + # in the following we test if the string " [a-zA-Z]" follows + # the Keyword.Type. + # Of course it is not perfect ! + (r'(Braid|FitResult|Label|Legend|TreeNode|abscissa|arc|arrowhead|' + r'binarytree|binarytreeNode|block|bool|bool3|bounds|bqe|circle|' + r'conic|coord|coordsys|cputime|ellipse|file|filltype|frame|grid3|' + r'guide|horner|hsv|hyperbola|indexedTransform|int|inversion|key|' + r'light|line|linefit|marginT|marker|mass|object|pair|parabola|path|' + r'path3|pen|picture|point|position|projection|real|revolution|' + r'scaleT|scientific|segment|side|slice|splitface|string|surface|' + r'tensionSpecifier|ticklocate|ticksgridT|tickvalues|transform|' + r'transformation|tree|triangle|trilinear|triple|vector|' + r'vertex|void)(?=\s+[a-zA-Z])', Keyword.Type), + # Now the asy-type-name which are not asy-function-name + # except yours ! + # Perhaps useless + (r'(Braid|FitResult|TreeNode|abscissa|arrowhead|block|bool|bool3|' + r'bounds|coord|frame|guide|horner|int|linefit|marginT|pair|pen|' + r'picture|position|real|revolution|slice|splitface|ticksgridT|' + r'tickvalues|tree|triple|vertex|void)\b', Keyword.Type), + (r'[a-zA-Z_]\w*:(?!:)', Name.Label), + (r'[a-zA-Z_]\w*', Name), + ], + 'root': [ + include('whitespace'), + # functions + (r'((?:[\w*\s])+?(?:\s|\*))' # return arguments + r'([a-zA-Z_]\w*)' # method name + r'(\s*\([^;]*?\))' # signature + r'(' + _ws + r')(\{)', + bygroups(using(this), Name.Function, using(this), using(this), + Punctuation), + 'function'), + # function declarations + (r'((?:[\w*\s])+?(?:\s|\*))' # return arguments + r'([a-zA-Z_]\w*)' # method name + r'(\s*\([^;]*?\))' # signature + r'(' + _ws + r')(;)', + bygroups(using(this), Name.Function, using(this), using(this), + Punctuation)), + default('statement'), + ], + 'statement': [ + include('whitespace'), + include('statements'), + ('[{}]', Punctuation), + (';', Punctuation, '#pop'), + ], + 'function': [ + include('whitespace'), + include('statements'), + (';', Punctuation), + (r'\{', Punctuation, '#push'), + (r'\}', Punctuation, '#pop'), + ], + 'string': [ + (r"'", String, '#pop'), + (r'\\([\\abfnrtv"\'?]|x[a-fA-F0-9]{2,4}|[0-7]{1,3})', String.Escape), + (r'\n', String), + (r"[^\\'\n]+", String), # all other characters + (r'\\\n', String), + (r'\\n', String), # line continuation + (r'\\', String), # stray backslash + ], + } + + def get_tokens_unprocessed(self, text): + from pygments.lexers._asy_builtins import ASYFUNCNAME, ASYVARNAME + for index, token, value in \ + RegexLexer.get_tokens_unprocessed(self, text): + if token is Name and value in ASYFUNCNAME: + token = Name.Function + elif token is Name and value in ASYVARNAME: + token = Name.Variable + yield index, token, value + + +def _shortened(word): + dpos = word.find('$') + return '|'.join(word[:dpos] + word[dpos+1:i] + r'\b' + for i in range(len(word), dpos, -1)) + + +def _shortened_many(*words): + return '|'.join(map(_shortened, words)) + + +class GnuplotLexer(RegexLexer): + """ + For Gnuplot plotting scripts. + """ + + name = 'Gnuplot' + url = 'http://gnuplot.info/' + aliases = ['gnuplot'] + filenames = ['*.plot', '*.plt'] + mimetypes = ['text/x-gnuplot'] + version_added = '0.11' + + tokens = { + 'root': [ + include('whitespace'), + (_shortened('bi$nd'), Keyword, 'bind'), + (_shortened_many('ex$it', 'q$uit'), Keyword, 'quit'), + (_shortened('f$it'), Keyword, 'fit'), + (r'(if)(\s*)(\()', bygroups(Keyword, Text, Punctuation), 'if'), + (r'else\b', Keyword), + (_shortened('pa$use'), Keyword, 'pause'), + (_shortened_many('p$lot', 'rep$lot', 'sp$lot'), Keyword, 'plot'), + (_shortened('sa$ve'), Keyword, 'save'), + (_shortened('se$t'), Keyword, ('genericargs', 'optionarg')), + (_shortened_many('sh$ow', 'uns$et'), + Keyword, ('noargs', 'optionarg')), + (_shortened_many('low$er', 'ra$ise', 'ca$ll', 'cd$', 'cl$ear', + 'h$elp', '\\?$', 'hi$story', 'l$oad', 'pr$int', + 'pwd$', 're$read', 'res$et', 'scr$eendump', + 'she$ll', 'sy$stem', 'up$date'), + Keyword, 'genericargs'), + (_shortened_many('pwd$', 're$read', 'res$et', 'scr$eendump', + 'she$ll', 'test$'), + Keyword, 'noargs'), + (r'([a-zA-Z_]\w*)(\s*)(=)', + bygroups(Name.Variable, Whitespace, Operator), 'genericargs'), + (r'([a-zA-Z_]\w*)(\s*)(\()(.*?)(\))(\s*)(=)', + bygroups(Name.Function, Whitespace, Punctuation, + Text, Punctuation, Whitespace, Operator), 'genericargs'), + (r'@[a-zA-Z_]\w*', Name.Constant), # macros + (r';', Keyword), + ], + 'comment': [ + (r'[^\\\n]+', Comment), + (r'\\\n', Comment), + (r'\\', Comment), + # don't add the newline to the Comment token + default('#pop'), + ], + 'whitespace': [ + ('#', Comment, 'comment'), + (r'[ \t\v\f]+', Whitespace), + ], + 'noargs': [ + include('whitespace'), + # semicolon and newline end the argument list + (r';', Punctuation, '#pop'), + (r'\n', Whitespace, '#pop'), + ], + 'dqstring': [ + (r'"', String, '#pop'), + (r'\\([\\abfnrtv"\']|x[a-fA-F0-9]{2,4}|[0-7]{1,3})', String.Escape), + (r'[^\\"\n]+', String), # all other characters + (r'\\\n', String), # line continuation + (r'\\', String), # stray backslash + (r'\n', Whitespace, '#pop'), # newline ends the string too + ], + 'sqstring': [ + (r"''", String), # escaped single quote + (r"'", String, '#pop'), + (r"[^\\'\n]+", String), # all other characters + (r'\\\n', String), # line continuation + (r'\\', String), # normal backslash + (r'\n', Whitespace, '#pop'), # newline ends the string too + ], + 'genericargs': [ + include('noargs'), + (r'"', String, 'dqstring'), + (r"'", String, 'sqstring'), + (r'(\d+\.\d*|\.\d+|\d+)[eE][+-]?\d+', Number.Float), + (r'(\d+\.\d*|\.\d+)', Number.Float), + (r'-?\d+', Number.Integer), + ('[,.~!%^&*+=|?:<>/-]', Operator), + (r'[{}()\[\]]', Punctuation), + (r'(eq|ne)\b', Operator.Word), + (r'([a-zA-Z_]\w*)(\s*)(\()', + bygroups(Name.Function, Text, Punctuation)), + (r'[a-zA-Z_]\w*', Name), + (r'@[a-zA-Z_]\w*', Name.Constant), # macros + (r'(\\)(\n)', bygroups(Text, Whitespace)), + ], + 'optionarg': [ + include('whitespace'), + (_shortened_many( + "a$ll", "an$gles", "ar$row", "au$toscale", "b$ars", "bor$der", + "box$width", "cl$abel", "c$lip", "cn$trparam", "co$ntour", "da$ta", + "data$file", "dg$rid3d", "du$mmy", "enc$oding", "dec$imalsign", + "fit$", "font$path", "fo$rmat", "fu$nction", "fu$nctions", "g$rid", + "hid$den3d", "his$torysize", "is$osamples", "k$ey", "keyt$itle", + "la$bel", "li$nestyle", "ls$", "loa$dpath", "loc$ale", "log$scale", + "mac$ros", "map$ping", "map$ping3d", "mar$gin", "lmar$gin", + "rmar$gin", "tmar$gin", "bmar$gin", "mo$use", "multi$plot", + "mxt$ics", "nomxt$ics", "mx2t$ics", "nomx2t$ics", "myt$ics", + "nomyt$ics", "my2t$ics", "nomy2t$ics", "mzt$ics", "nomzt$ics", + "mcbt$ics", "nomcbt$ics", "of$fsets", "or$igin", "o$utput", + "pa$rametric", "pm$3d", "pal$ette", "colorb$ox", "p$lot", + "poi$ntsize", "pol$ar", "pr$int", "obj$ect", "sa$mples", "si$ze", + "st$yle", "su$rface", "table$", "t$erminal", "termo$ptions", "ti$cs", + "ticsc$ale", "ticsl$evel", "timef$mt", "tim$estamp", "tit$le", + "v$ariables", "ve$rsion", "vi$ew", "xyp$lane", "xda$ta", "x2da$ta", + "yda$ta", "y2da$ta", "zda$ta", "cbda$ta", "xl$abel", "x2l$abel", + "yl$abel", "y2l$abel", "zl$abel", "cbl$abel", "xti$cs", "noxti$cs", + "x2ti$cs", "nox2ti$cs", "yti$cs", "noyti$cs", "y2ti$cs", "noy2ti$cs", + "zti$cs", "nozti$cs", "cbti$cs", "nocbti$cs", "xdti$cs", "noxdti$cs", + "x2dti$cs", "nox2dti$cs", "ydti$cs", "noydti$cs", "y2dti$cs", + "noy2dti$cs", "zdti$cs", "nozdti$cs", "cbdti$cs", "nocbdti$cs", + "xmti$cs", "noxmti$cs", "x2mti$cs", "nox2mti$cs", "ymti$cs", + "noymti$cs", "y2mti$cs", "noy2mti$cs", "zmti$cs", "nozmti$cs", + "cbmti$cs", "nocbmti$cs", "xr$ange", "x2r$ange", "yr$ange", + "y2r$ange", "zr$ange", "cbr$ange", "rr$ange", "tr$ange", "ur$ange", + "vr$ange", "xzeroa$xis", "x2zeroa$xis", "yzeroa$xis", "y2zeroa$xis", + "zzeroa$xis", "zeroa$xis", "z$ero"), Name.Builtin, '#pop'), + ], + 'bind': [ + ('!', Keyword, '#pop'), + (_shortened('all$windows'), Name.Builtin), + include('genericargs'), + ], + 'quit': [ + (r'gnuplot\b', Keyword), + include('noargs'), + ], + 'fit': [ + (r'via\b', Name.Builtin), + include('plot'), + ], + 'if': [ + (r'\)', Punctuation, '#pop'), + include('genericargs'), + ], + 'pause': [ + (r'(mouse|any|button1|button2|button3)\b', Name.Builtin), + (_shortened('key$press'), Name.Builtin), + include('genericargs'), + ], + 'plot': [ + (_shortened_many('ax$es', 'axi$s', 'bin$ary', 'ev$ery', 'i$ndex', + 'mat$rix', 's$mooth', 'thru$', 't$itle', + 'not$itle', 'u$sing', 'w$ith'), + Name.Builtin), + include('genericargs'), + ], + 'save': [ + (_shortened_many('f$unctions', 's$et', 't$erminal', 'v$ariables'), + Name.Builtin), + include('genericargs'), + ], + } + + +class PovrayLexer(RegexLexer): + """ + For Persistence of Vision Raytracer files. + """ + name = 'POVRay' + url = 'http://www.povray.org/' + aliases = ['pov'] + filenames = ['*.pov', '*.inc'] + mimetypes = ['text/x-povray'] + version_added = '0.11' + + tokens = { + 'root': [ + (r'/\*[\w\W]*?\*/', Comment.Multiline), + (r'//.*$', Comment.Single), + (r'(?s)"(?:\\.|[^"\\])+"', String.Double), + (words(( + 'break', 'case', 'debug', 'declare', 'default', 'define', 'else', + 'elseif', 'end', 'error', 'fclose', 'fopen', 'for', 'if', 'ifdef', + 'ifndef', 'include', 'local', 'macro', 'range', 'read', 'render', + 'statistics', 'switch', 'undef', 'version', 'warning', 'while', + 'write'), prefix=r'#', suffix=r'\b'), + Comment.Preproc), + (words(( + 'aa_level', 'aa_threshold', 'abs', 'acos', 'acosh', 'adaptive', 'adc_bailout', + 'agate', 'agate_turb', 'all', 'alpha', 'ambient', 'ambient_light', 'angle', + 'aperture', 'arc_angle', 'area_light', 'asc', 'asin', 'asinh', 'assumed_gamma', + 'atan', 'atan2', 'atanh', 'atmosphere', 'atmospheric_attenuation', + 'attenuating', 'average', 'background', 'black_hole', 'blue', 'blur_samples', + 'bounded_by', 'box_mapping', 'bozo', 'break', 'brick', 'brick_size', + 'brightness', 'brilliance', 'bumps', 'bumpy1', 'bumpy2', 'bumpy3', 'bump_map', + 'bump_size', 'case', 'caustics', 'ceil', 'checker', 'chr', 'clipped_by', 'clock', + 'color', 'color_map', 'colour', 'colour_map', 'component', 'composite', 'concat', + 'confidence', 'conic_sweep', 'constant', 'control0', 'control1', 'cos', 'cosh', + 'count', 'crackle', 'crand', 'cube', 'cubic_spline', 'cylindrical_mapping', + 'debug', 'declare', 'default', 'degrees', 'dents', 'diffuse', 'direction', + 'distance', 'distance_maximum', 'div', 'dust', 'dust_type', 'eccentricity', + 'else', 'emitting', 'end', 'error', 'error_bound', 'exp', 'exponent', + 'fade_distance', 'fade_power', 'falloff', 'falloff_angle', 'false', + 'file_exists', 'filter', 'finish', 'fisheye', 'flatness', 'flip', 'floor', + 'focal_point', 'fog', 'fog_alt', 'fog_offset', 'fog_type', 'frequency', 'gif', + 'global_settings', 'glowing', 'gradient', 'granite', 'gray_threshold', + 'green', 'halo', 'hexagon', 'hf_gray_16', 'hierarchy', 'hollow', 'hypercomplex', + 'if', 'ifdef', 'iff', 'image_map', 'incidence', 'include', 'int', 'interpolate', + 'inverse', 'ior', 'irid', 'irid_wavelength', 'jitter', 'lambda', 'leopard', + 'linear', 'linear_spline', 'linear_sweep', 'location', 'log', 'looks_like', + 'look_at', 'low_error_factor', 'mandel', 'map_type', 'marble', 'material_map', + 'matrix', 'max', 'max_intersections', 'max_iteration', 'max_trace_level', + 'max_value', 'metallic', 'min', 'minimum_reuse', 'mod', 'mortar', + 'nearest_count', 'no', 'normal', 'normal_map', 'no_shadow', 'number_of_waves', + 'octaves', 'off', 'offset', 'omega', 'omnimax', 'on', 'once', 'onion', 'open', + 'orthographic', 'panoramic', 'pattern1', 'pattern2', 'pattern3', + 'perspective', 'pgm', 'phase', 'phong', 'phong_size', 'pi', 'pigment', + 'pigment_map', 'planar_mapping', 'png', 'point_at', 'pot', 'pow', 'ppm', + 'precision', 'pwr', 'quadratic_spline', 'quaternion', 'quick_color', + 'quick_colour', 'quilted', 'radial', 'radians', 'radiosity', 'radius', 'rainbow', + 'ramp_wave', 'rand', 'range', 'reciprocal', 'recursion_limit', 'red', + 'reflection', 'refraction', 'render', 'repeat', 'rgb', 'rgbf', 'rgbft', 'rgbt', + 'right', 'ripples', 'rotate', 'roughness', 'samples', 'scale', 'scallop_wave', + 'scattering', 'seed', 'shadowless', 'sin', 'sine_wave', 'sinh', 'sky', 'sky_sphere', + 'slice', 'slope_map', 'smooth', 'specular', 'spherical_mapping', 'spiral', + 'spiral1', 'spiral2', 'spotlight', 'spotted', 'sqr', 'sqrt', 'statistics', 'str', + 'strcmp', 'strength', 'strlen', 'strlwr', 'strupr', 'sturm', 'substr', 'switch', 'sys', + 't', 'tan', 'tanh', 'test_camera_1', 'test_camera_2', 'test_camera_3', + 'test_camera_4', 'texture', 'texture_map', 'tga', 'thickness', 'threshold', + 'tightness', 'tile2', 'tiles', 'track', 'transform', 'translate', 'transmit', + 'triangle_wave', 'true', 'ttf', 'turbulence', 'turb_depth', 'type', + 'ultra_wide_angle', 'up', 'use_color', 'use_colour', 'use_index', 'u_steps', + 'val', 'variance', 'vaxis_rotate', 'vcross', 'vdot', 'version', 'vlength', + 'vnormalize', 'volume_object', 'volume_rendered', 'vol_with_light', + 'vrotate', 'v_steps', 'warning', 'warp', 'water_level', 'waves', 'while', 'width', + 'wood', 'wrinkles', 'yes'), prefix=r'\b', suffix=r'\b'), + Keyword), + (words(( + 'bicubic_patch', 'blob', 'box', 'camera', 'cone', 'cubic', 'cylinder', 'difference', + 'disc', 'height_field', 'intersection', 'julia_fractal', 'lathe', + 'light_source', 'merge', 'mesh', 'object', 'plane', 'poly', 'polygon', 'prism', + 'quadric', 'quartic', 'smooth_triangle', 'sor', 'sphere', 'superellipsoid', + 'text', 'torus', 'triangle', 'union'), suffix=r'\b'), + Name.Builtin), + (r'\b(x|y|z|u|v)\b', Name.Builtin.Pseudo), + (r'[a-zA-Z_]\w*', Name), + (r'[0-9]*\.[0-9]+', Number.Float), + (r'[0-9]+', Number.Integer), + (r'[\[\](){}<>;,]', Punctuation), + (r'[-+*/=.|&]|<=|>=|!=', Operator), + (r'"(\\\\|\\[^\\]|[^"\\])*"', String), + (r'\s+', Whitespace), + ] + } + + def analyse_text(text): + """POVRAY is similar to JSON/C, but the combination of camera and + light_source is probably not very likely elsewhere. HLSL or GLSL + are similar (GLSL even has #version), but they miss #declare, and + light_source/camera are not keywords anywhere else -- it's fair + to assume though that any POVRAY scene must have a camera and + lightsource.""" + result = 0 + if '#version' in text: + result += 0.05 + if '#declare' in text: + result += 0.05 + if 'camera' in text: + result += 0.05 + if 'light_source' in text: + result += 0.1 + + return result diff --git a/pygments/lexers/graphql.py b/pygments/lexers/graphql.py new file mode 100644 index 0000000000000000000000000000000000000000..7de5b6e7cf41d8e4337319f2847d967be56cf80f --- /dev/null +++ b/pygments/lexers/graphql.py @@ -0,0 +1,176 @@ +""" + pygments.lexers.graphql + ~~~~~~~~~~~~~~~~~~~~~~~ + + Lexer for GraphQL, an open-source data query and manipulation + language for APIs. + + More information: + https://graphql.org/ + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +from pygments.lexer import RegexLexer, words, include, bygroups, default +from pygments.token import (Comment, Keyword, Name, Number, Punctuation, String, + Whitespace) + + +__all__ = ["GraphQLLexer"] + +OPERATION_TYPES = ("query", "mutation", "subscription") +BUILTIN_TYPES = ("Int", "Float", "String", "Boolean", "ID") +BOOLEAN_VALUES = ("true", "false", "null") +KEYWORDS = ( + "type", + "schema", + "extend", + "enum", + "scalar", + "implements", + "interface", + "union", + "input", + "directive", + "QUERY", + "MUTATION", + "SUBSCRIPTION", + "FIELD", + "FRAGMENT_DEFINITION", + "FRAGMENT_SPREAD", + "INLINE_FRAGMENT", + "SCHEMA", + "SCALAR", + "OBJECT", + "FIELD_DEFINITION", + "ARGUMENT_DEFINITION", + "INTERFACE", + "UNION", + "ENUM", + "ENUM_VALUE", + "INPUT_OBJECT", + "INPUT_FIELD_DEFINITION", +) + + +class GraphQLLexer(RegexLexer): + """ + Lexer for GraphQL syntax + """ + name = "GraphQL" + aliases = ["graphql"] + filenames = ["*.graphql"] + url = "https://graphql.org" + version_added = '2.16' + + tokens = { + "ignored_tokens": [ + (r"\s+", Whitespace), # Whitespaces + (r"#.*$", Comment), + (",", Punctuation), # Insignificant commas + ], + "value": [ + include("ignored_tokens"), + (r"-?\d+(?![.eE])", Number.Integer, "#pop"), + ( + r"-?\d+(\.\d+)?([eE][+-]?\d+)?", + Number.Float, + "#pop", + ), + (r'"', String, ("#pop", "string")), + (words(BOOLEAN_VALUES, suffix=r"\b"), Name.Builtin, "#pop"), + (r"\$[a-zA-Z_]\w*", Name.Variable, "#pop"), + (r"[a-zA-Z_]\w*", Name.Constant, "#pop"), + (r"\[", Punctuation, ("#pop", "list_value")), + (r"\{", Punctuation, ("#pop", "object_value")), + ], + "list_value": [ + include("ignored_tokens"), + ("]", Punctuation, "#pop"), + default("value"), + ], + "object_value": [ + include("ignored_tokens"), + (r"[a-zA-Z_]\w*", Name), + (r":", Punctuation, "value"), + (r"\}", Punctuation, "#pop"), + ], + "string": [ + (r'\\(["\\/bfnrt]|u[a-fA-F0-9]{4})', String.Escape), + (r'[^\\"\n]+', String), # all other characters + (r'"', String, "#pop"), + ], + "root": [ + include("ignored_tokens"), + (words(OPERATION_TYPES, suffix=r"\b"), Keyword, "operation"), + (words(KEYWORDS, suffix=r"\b"), Keyword), + (r"\{", Punctuation, "selection_set"), + (r"fragment\b", Keyword, "fragment_definition"), + ], + "operation": [ + include("ignored_tokens"), + (r"[a-zA-Z_]\w*", Name.Function), + (r"\(", Punctuation, "variable_definition"), + (r"\{", Punctuation, ("#pop", "selection_set")), + ], + "variable_definition": [ + include("ignored_tokens"), + (r"\$[a-zA-Z_]\w*", Name.Variable), + (r"[\]!]", Punctuation), + (r":", Punctuation, "type"), + (r"=", Punctuation, "value"), + (r"\)", Punctuation, "#pop"), + ], + "type": [ + include("ignored_tokens"), + (r"\[", Punctuation), + (words(BUILTIN_TYPES, suffix=r"\b"), Name.Builtin, "#pop"), + (r"[a-zA-Z_]\w*", Name.Class, "#pop"), + ], + "selection_set": [ + include("ignored_tokens"), + (r"([a-zA-Z_]\w*)(\s*)(:)", bygroups(Name.Label, Whitespace, Punctuation)), + (r"[a-zA-Z_]\w*", Name), # Field + ( + r"(\.\.\.)(\s+)(on)\b", + bygroups(Punctuation, Whitespace, Keyword), + "inline_fragment", + ), + (r"\.\.\.", Punctuation, "fragment_spread"), + (r"\(", Punctuation, "arguments"), + (r"@[a-zA-Z_]\w*", Name.Decorator, "directive"), + (r"\{", Punctuation, "selection_set"), + (r"\}", Punctuation, "#pop"), + ], + "directive": [ + include("ignored_tokens"), + (r"\(", Punctuation, ("#pop", "arguments")), + ], + "arguments": [ + include("ignored_tokens"), + (r"[a-zA-Z_]\w*", Name), + (r":", Punctuation, "value"), + (r"\)", Punctuation, "#pop"), + ], + # Fragments + "fragment_definition": [ + include("ignored_tokens"), + (r"[\]!]", Punctuation), # For NamedType + (r"on\b", Keyword, "type"), + (r"[a-zA-Z_]\w*", Name.Function), + (r"@[a-zA-Z_]\w*", Name.Decorator, "directive"), + (r"\{", Punctuation, ("#pop", "selection_set")), + ], + "fragment_spread": [ + include("ignored_tokens"), + (r"@[a-zA-Z_]\w*", Name.Decorator, "directive"), + (r"[a-zA-Z_]\w*", Name, "#pop"), # Fragment name + ], + "inline_fragment": [ + include("ignored_tokens"), + (r"[a-zA-Z_]\w*", Name.Class), # Type condition + (r"@[a-zA-Z_]\w*", Name.Decorator, "directive"), + (r"\{", Punctuation, ("#pop", "selection_set")), + ], + } diff --git a/pygments/lexers/graphviz.py b/pygments/lexers/graphviz.py new file mode 100644 index 0000000000000000000000000000000000000000..6b4841554f4acb570faa5896e0e5042bbfedc225 --- /dev/null +++ b/pygments/lexers/graphviz.py @@ -0,0 +1,58 @@ +""" + pygments.lexers.graphviz + ~~~~~~~~~~~~~~~~~~~~~~~~ + + Lexer for the DOT language (graphviz). + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +from pygments.lexer import RegexLexer, bygroups +from pygments.token import Comment, Keyword, Operator, Name, String, Number, \ + Punctuation, Whitespace + + +__all__ = ['GraphvizLexer'] + + +class GraphvizLexer(RegexLexer): + """ + For graphviz DOT graph description language. + """ + name = 'Graphviz' + url = 'https://www.graphviz.org/doc/info/lang.html' + aliases = ['graphviz', 'dot'] + filenames = ['*.gv', '*.dot'] + mimetypes = ['text/x-graphviz', 'text/vnd.graphviz'] + version_added = '2.8' + tokens = { + 'root': [ + (r'\s+', Whitespace), + (r'(#|//).*?$', Comment.Single), + (r'/(\\\n)?[*](.|\n)*?[*](\\\n)?/', Comment.Multiline), + (r'(?i)(node|edge|graph|digraph|subgraph|strict)\b', Keyword), + (r'--|->', Operator), + (r'[{}[\]:;,]', Punctuation), + (r'(\b\D\w*)(\s*)(=)(\s*)', + bygroups(Name.Attribute, Whitespace, Punctuation, Whitespace), + 'attr_id'), + (r'\b(n|ne|e|se|s|sw|w|nw|c|_)\b', Name.Builtin), + (r'\b\D\w*', Name.Tag), # node + (r'[-]?((\.[0-9]+)|([0-9]+(\.[0-9]*)?))', Number), + (r'"(\\"|[^"])*?"', Name.Tag), # quoted node + (r'<', Punctuation, 'xml'), + ], + 'attr_id': [ + (r'\b\D\w*', String, '#pop'), + (r'[-]?((\.[0-9]+)|([0-9]+(\.[0-9]*)?))', Number, '#pop'), + (r'"(\\"|[^"])*?"', String.Double, '#pop'), + (r'<', Punctuation, ('#pop', 'xml')), + ], + 'xml': [ + (r'<', Punctuation, '#push'), + (r'>', Punctuation, '#pop'), + (r'\s+', Whitespace), + (r'[^<>\s]', Name.Tag), + ] + } diff --git a/pygments/lexers/gsql.py b/pygments/lexers/gsql.py new file mode 100644 index 0000000000000000000000000000000000000000..9ff187889f74e537bb9c42d56524760adc2d21d0 --- /dev/null +++ b/pygments/lexers/gsql.py @@ -0,0 +1,103 @@ +""" + pygments.lexers.gsql + ~~~~~~~~~~~~~~~~~~~~ + + Lexers for TigerGraph GSQL graph query language + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +import re + +from pygments.lexer import RegexLexer, include, bygroups, using, this, words +from pygments.token import Keyword, Punctuation, Comment, Operator, Name, \ + String, Number, Whitespace + +__all__ = ["GSQLLexer"] + + +class GSQLLexer(RegexLexer): + + """ + For GSQL queries (version 3.x). + """ + + name = 'GSQL' + url = 'https://docs.tigergraph.com/dev/gsql-ref' + aliases = ['gsql'] + filenames = ['*.gsql'] + version_added = '2.10' + + flags = re.MULTILINE | re.IGNORECASE + + tokens = { + 'root': [ + include('comment'), + include('keywords'), + include('clauses'), + include('accums'), + include('relations'), + include('strings'), + include('whitespace'), + include('barewords'), + include('operators'), + ], + 'comment': [ + (r'\#.*', Comment.Single), + (r'/\*(.|\n)*?\*/', Comment.Multiline), + ], + 'keywords': [ + (words(( + 'ACCUM', 'AND', 'ANY', 'API', 'AS', 'ASC', 'AVG', 'BAG', 'BATCH', + 'BETWEEN', 'BOOL', 'BOTH', 'BREAK', 'BY', 'CASE', 'CATCH', 'COALESCE', + 'COMPRESS', 'CONTINUE', 'COUNT', 'CREATE', 'DATETIME', 'DATETIME_ADD', + 'DATETIME_SUB', 'DELETE', 'DESC', 'DISTRIBUTED', 'DO', 'DOUBLE', + 'EDGE', 'ELSE', 'END', 'ESCAPE', 'EXCEPTION', 'FALSE', 'FILE', + 'FILTER', 'FLOAT', 'FOREACH', 'FOR', 'FROM', 'GRAPH', 'GROUP', + 'GSQL_INT_MAX', 'GSQL_INT_MIN', 'GSQL_UINT_MAX', 'HAVING', 'IF', + 'IN', 'INSERT', 'INT', 'INTERPRET', 'INTERSECT', 'INTERVAL', 'INTO', + 'IS', 'ISEMPTY', 'JSONARRAY', 'JSONOBJECT', 'LASTHOP', 'LEADING', + 'LIKE', 'LIMIT', 'LIST', 'LOAD_ACCUM', 'LOG', 'MAP', 'MATCH', 'MAX', + 'MIN', 'MINUS', 'NOT', 'NOW', 'NULL', 'OFFSET', 'OR', 'ORDER', 'PATH', + 'PER', 'PINNED', 'POST_ACCUM', 'POST-ACCUM', 'PRIMARY_ID', 'PRINT', + 'QUERY', 'RAISE', 'RANGE', 'REPLACE', 'RESET_COLLECTION_ACCUM', + 'RETURN', 'RETURNS', 'RUN', 'SAMPLE', 'SELECT', 'SELECT_VERTEX', + 'SET', 'SRC', 'STATIC', 'STRING', 'SUM', 'SYNTAX', 'TARGET', + 'TAGSTGT', 'THEN', 'TO', 'TO_CSV', 'TO_DATETIME', 'TRAILING', + 'TRIM', 'TRUE', 'TRY', 'TUPLE', 'TYPEDEF', 'UINT', 'UNION', 'UPDATE', + 'VALUES', 'VERTEX', 'WHEN', 'WHERE', 'WHILE', 'WITH'), + prefix=r'(?|<-', Operator), + (r'[.*{}\[\]\<\>\_]', Punctuation), + ], + 'strings': [ + (r'"([^"\\]|\\.)*"', String), + (r'@{1,2}\w+', Name.Variable), + ], + 'whitespace': [ + (r'\s+', Whitespace), + ], + 'barewords': [ + (r'[a-z]\w*', Name), + (r'(\d+\.\d+|\d+)', Number), + ], + 'operators': [ + (r'\$|[^0-9|\/|\-](\-\=|\+\=|\*\=|\\\=|\=|\=\=|\=\=\=|' + r'\+|\-|\*|\\|\+\=|\>|\<)[^\>|\/]', Operator), + (r'(\||\(|\)|\,|\;|\=|\-|\+|\*|\/|\>|\<|\:)', Operator), + ], + } diff --git a/pygments/lexers/hare.py b/pygments/lexers/hare.py new file mode 100644 index 0000000000000000000000000000000000000000..56548d5e59f3ca0968d79f7e76cb67dc4ac817fd --- /dev/null +++ b/pygments/lexers/hare.py @@ -0,0 +1,73 @@ +""" + pygments.lexers.hare + ~~~~~~~~~~~~~~~~~~~~ + + Lexers for the Hare language. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +from pygments.lexer import RegexLexer, include, words +from pygments.token import Comment, Operator, Keyword, Name, String, \ + Number, Punctuation, Whitespace + +__all__ = ['HareLexer'] + +class HareLexer(RegexLexer): + """ + Lexer for the Hare programming language. + """ + name = 'Hare' + url = 'https://harelang.org/' + aliases = ['hare'] + filenames = ['*.ha'] + mimetypes = ['text/x-hare'] + version_added = '2.19' + + _ws = r'(?:\s|//.*?\n|/[*].*?[*]/)+' + _ws1 = r'\s*(?:/[*].*?[*]/\s*)?' + + tokens = { + 'whitespace': [ + (r'^use.*;', Comment.Preproc), + (r'@[a-z]+', Comment.Preproc), + (r'\n', Whitespace), + (r'\s+', Whitespace), + (r'//.*?$', Comment.Single), + ], + 'statements': [ + (r'"', String, 'string'), + (r'`[^`]*`', String), + (r"'(\\.|\\[0-7]{1,3}|\\x[a-fA-F0-9]{1,2}|[^\\\'\n])'", String.Char), + (r'(\d+\.\d*|\.\d+|\d+)[eE][+-]?\d+[LlUu]*', Number.Float), + (r'(\d+\.\d*|\.\d+|\d+[fF])[fF]?', Number.Float), + (r'0x[0-9a-fA-F]+[LlUu]*', Number.Hex), + (r'0o[0-7]+[LlUu]*', Number.Oct), + (r'\d+[zui]?(\d+)?', Number.Integer), + (r'[~!%^&*+=|?:<>/-]', Operator), + (words(('as', 'is', '=>', '..', '...')), Operator), + (r'[()\[\],.{};]+', Punctuation), + (words(('abort', 'align', 'alloc', 'append', 'assert', 'case', + 'const', 'def', 'defer', 'delete', 'else', 'enum', 'export', + 'fn', 'for', 'free', 'if', 'let', 'len', 'match', 'offset', + 'return', 'static', 'struct', 'switch', 'type', 'union', + 'yield', 'vastart', 'vaarg', 'vaend'), + suffix=r'\b'), Keyword), + (r'(bool|int|uint|uintptr|u8|u16|u32|u64|i8|i16|i32|i64|f32|f64|null|done|never|void|nullable|rune|size|valist)\b', + Keyword.Type), + (r'(true|false|null)\b', Name.Builtin), + (r'[a-zA-Z_]\w*', Name), + ], + 'string': [ + (r'"', String, '#pop'), + (r'\\([\\abfnrtv"\']|x[a-fA-F0-9]{2,4}|' + r'u[a-fA-F0-9]{4}|U[a-fA-F0-9]{8}|[0-7]{1,3})', String.Escape), + (r'[^\\"\n]+', String), # all other characters + (r'\\', String), # stray backslash + ], + 'root': [ + include('whitespace'), + include('statements'), + ], + } diff --git a/pygments/lexers/haskell.py b/pygments/lexers/haskell.py new file mode 100644 index 0000000000000000000000000000000000000000..0a9f2c15a740fb95472ca21c7b8bc8ac6ecb241c --- /dev/null +++ b/pygments/lexers/haskell.py @@ -0,0 +1,866 @@ +""" + pygments.lexers.haskell + ~~~~~~~~~~~~~~~~~~~~~~~ + + Lexers for Haskell and related languages. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +import re + +from pygments.lexer import Lexer, RegexLexer, bygroups, do_insertions, \ + default, include, inherit, line_re +from pygments.token import Text, Comment, Operator, Keyword, Name, String, \ + Number, Punctuation, Generic, Whitespace +from pygments import unistring as uni + +__all__ = ['HaskellLexer', 'HspecLexer', 'IdrisLexer', 'AgdaLexer', 'CryptolLexer', + 'LiterateHaskellLexer', 'LiterateIdrisLexer', 'LiterateAgdaLexer', + 'LiterateCryptolLexer', 'KokaLexer'] + + +class HaskellLexer(RegexLexer): + """ + A Haskell lexer based on the lexemes defined in the Haskell 98 Report. + """ + name = 'Haskell' + url = 'https://www.haskell.org/' + aliases = ['haskell', 'hs'] + filenames = ['*.hs'] + mimetypes = ['text/x-haskell'] + version_added = '0.8' + + reserved = ('case', 'class', 'data', 'default', 'deriving', 'do', 'else', + 'family', 'if', 'in', 'infix[lr]?', 'instance', + 'let', 'newtype', 'of', 'then', 'type', 'where', '_') + ascii = ('NUL', 'SOH', '[SE]TX', 'EOT', 'ENQ', 'ACK', + 'BEL', 'BS', 'HT', 'LF', 'VT', 'FF', 'CR', 'S[OI]', 'DLE', + 'DC[1-4]', 'NAK', 'SYN', 'ETB', 'CAN', + 'EM', 'SUB', 'ESC', '[FGRU]S', 'SP', 'DEL') + + tokens = { + 'root': [ + # Whitespace: + (r'\s+', Whitespace), + # (r'--\s*|.*$', Comment.Doc), + (r'--(?![!#$%&*+./<=>?@^|_~:\\]).*?$', Comment.Single), + (r'\{-', Comment.Multiline, 'comment'), + # Lexemes: + # Identifiers + (r'\bimport\b', Keyword.Reserved, 'import'), + (r'\bmodule\b', Keyword.Reserved, 'module'), + (r'\berror\b', Name.Exception), + (r'\b({})(?!\')\b'.format('|'.join(reserved)), Keyword.Reserved), + (r"'[^\\]'", String.Char), # this has to come before the TH quote + (r'^[_' + uni.Ll + r'][\w\']*', Name.Function), + (r"'?[_" + uni.Ll + r"][\w']*", Name), + (r"('')?[" + uni.Lu + r"][\w\']*", Keyword.Type), + (r"(')[" + uni.Lu + r"][\w\']*", Keyword.Type), + (r"(')\[[^\]]*\]", Keyword.Type), # tuples and lists get special treatment in GHC + (r"(')\([^)]*\)", Keyword.Type), # .. + (r"(')[:!#$%&*+.\\/<=>?@^|~-]+", Keyword.Type), # promoted type operators + # Operators + (r'\\(?![:!#$%&*+.\\/<=>?@^|~-]+)', Name.Function), # lambda operator + (r'(<-|::|->|=>|=)(?![:!#$%&*+.\\/<=>?@^|~-]+)', Operator.Word), # specials + (r':[:!#$%&*+.\\/<=>?@^|~-]*', Keyword.Type), # Constructor operators + (r'[:!#$%&*+.\\/<=>?@^|~-]+', Operator), # Other operators + # Numbers + (r'0[xX]_*[\da-fA-F](_*[\da-fA-F])*_*[pP][+-]?\d(_*\d)*', Number.Float), + (r'0[xX]_*[\da-fA-F](_*[\da-fA-F])*\.[\da-fA-F](_*[\da-fA-F])*' + r'(_*[pP][+-]?\d(_*\d)*)?', Number.Float), + (r'\d(_*\d)*_*[eE][+-]?\d(_*\d)*', Number.Float), + (r'\d(_*\d)*\.\d(_*\d)*(_*[eE][+-]?\d(_*\d)*)?', Number.Float), + (r'0[bB]_*[01](_*[01])*', Number.Bin), + (r'0[oO]_*[0-7](_*[0-7])*', Number.Oct), + (r'0[xX]_*[\da-fA-F](_*[\da-fA-F])*', Number.Hex), + (r'\d(_*\d)*', Number.Integer), + # Character/String Literals + (r"'", String.Char, 'character'), + (r'"', String, 'string'), + # Special + (r'\[\]', Keyword.Type), + (r'\(\)', Name.Builtin), + (r'[][(),;`{}]', Punctuation), + ], + 'import': [ + # Import statements + (r'\s+', Whitespace), + (r'"', String, 'string'), + # after "funclist" state + (r'\)', Punctuation, '#pop'), + (r'qualified\b', Keyword), + # import X as Y + (r'([' + uni.Lu + r'][\w.]*)(\s+)(as)(\s+)([' + uni.Lu + r'][\w.]*)', + bygroups(Name.Namespace, Whitespace, Keyword, Whitespace, Name), '#pop'), + # import X hiding (functions) + (r'([' + uni.Lu + r'][\w.]*)(\s+)(hiding)(\s+)(\()', + bygroups(Name.Namespace, Whitespace, Keyword, Whitespace, Punctuation), 'funclist'), + # import X (functions) + (r'([' + uni.Lu + r'][\w.]*)(\s+)(\()', + bygroups(Name.Namespace, Whitespace, Punctuation), 'funclist'), + # import X + (r'[\w.]+', Name.Namespace, '#pop'), + ], + 'module': [ + (r'\s+', Whitespace), + (r'([' + uni.Lu + r'][\w.]*)(\s+)(\()', + bygroups(Name.Namespace, Whitespace, Punctuation), 'funclist'), + (r'[' + uni.Lu + r'][\w.]*', Name.Namespace, '#pop'), + ], + 'funclist': [ + (r'\s+', Whitespace), + (r'[' + uni.Lu + r']\w*', Keyword.Type), + (r'(_[\w\']+|[' + uni.Ll + r'][\w\']*)', Name.Function), + (r'--(?![!#$%&*+./<=>?@^|_~:\\]).*?$', Comment.Single), + (r'\{-', Comment.Multiline, 'comment'), + (r',', Punctuation), + (r'[:!#$%&*+.\\/<=>?@^|~-]+', Operator), + # (HACK, but it makes sense to push two instances, believe me) + (r'\(', Punctuation, ('funclist', 'funclist')), + (r'\)', Punctuation, '#pop:2'), + ], + # NOTE: the next four states are shared in the AgdaLexer; make sure + # any change is compatible with Agda as well or copy over and change + 'comment': [ + # Multiline Comments + (r'[^-{}]+', Comment.Multiline), + (r'\{-', Comment.Multiline, '#push'), + (r'-\}', Comment.Multiline, '#pop'), + (r'[-{}]', Comment.Multiline), + ], + 'character': [ + # Allows multi-chars, incorrectly. + (r"[^\\']'", String.Char, '#pop'), + (r"\\", String.Escape, 'escape'), + ("'", String.Char, '#pop'), + ], + 'string': [ + (r'[^\\"]+', String), + (r"\\", String.Escape, 'escape'), + ('"', String, '#pop'), + ], + 'escape': [ + (r'[abfnrtv"\'&\\]', String.Escape, '#pop'), + (r'\^[][' + uni.Lu + r'@^_]', String.Escape, '#pop'), + ('|'.join(ascii), String.Escape, '#pop'), + (r'o[0-7]+', String.Escape, '#pop'), + (r'x[\da-fA-F]+', String.Escape, '#pop'), + (r'\d+', String.Escape, '#pop'), + (r'(\s+)(\\)', bygroups(Whitespace, String.Escape), '#pop'), + ], + } + + +class HspecLexer(HaskellLexer): + """ + A Haskell lexer with support for Hspec constructs. + """ + + name = 'Hspec' + aliases = ['hspec'] + filenames = ['*Spec.hs'] + mimetypes = [] + version_added = '2.4' + + tokens = { + 'root': [ + (r'(it)(\s*)("[^"]*")', bygroups(Text, Whitespace, String.Doc)), + (r'(describe)(\s*)("[^"]*")', bygroups(Text, Whitespace, String.Doc)), + (r'(context)(\s*)("[^"]*")', bygroups(Text, Whitespace, String.Doc)), + inherit, + ], + } + + +class IdrisLexer(RegexLexer): + """ + A lexer for the dependently typed programming language Idris. + + Based on the Haskell and Agda Lexer. + """ + name = 'Idris' + url = 'https://www.idris-lang.org/' + aliases = ['idris', 'idr'] + filenames = ['*.idr'] + mimetypes = ['text/x-idris'] + version_added = '2.0' + + reserved = ('case', 'class', 'data', 'default', 'using', 'do', 'else', + 'if', 'in', 'infix[lr]?', 'instance', 'rewrite', 'auto', + 'namespace', 'codata', 'mutual', 'private', 'public', 'abstract', + 'total', 'partial', + 'interface', 'implementation', 'export', 'covering', 'constructor', + 'let', 'proof', 'of', 'then', 'static', 'where', '_', 'with', + 'pattern', 'term', 'syntax', 'prefix', + 'postulate', 'parameters', 'record', 'dsl', 'impossible', 'implicit', + 'tactics', 'intros', 'intro', 'compute', 'refine', 'exact', 'trivial') + + ascii = ('NUL', 'SOH', '[SE]TX', 'EOT', 'ENQ', 'ACK', + 'BEL', 'BS', 'HT', 'LF', 'VT', 'FF', 'CR', 'S[OI]', 'DLE', + 'DC[1-4]', 'NAK', 'SYN', 'ETB', 'CAN', + 'EM', 'SUB', 'ESC', '[FGRU]S', 'SP', 'DEL') + + directives = ('lib', 'link', 'flag', 'include', 'hide', 'freeze', 'access', + 'default', 'logging', 'dynamic', 'name', 'error_handlers', 'language') + + tokens = { + 'root': [ + # Comments + (r'^(\s*)(%({}))'.format('|'.join(directives)), + bygroups(Whitespace, Keyword.Reserved)), + (r'(\s*)(--(?![!#$%&*+./<=>?@^|_~:\\]).*?)$', bygroups(Whitespace, Comment.Single)), + (r'(\s*)(\|{3}.*?)$', bygroups(Whitespace, Comment.Single)), + (r'(\s*)(\{-)', bygroups(Whitespace, Comment.Multiline), 'comment'), + # Declaration + (r'^(\s*)([^\s(){}]+)(\s*)(:)(\s*)', + bygroups(Whitespace, Name.Function, Whitespace, Operator.Word, Whitespace)), + # Identifiers + (r'\b({})(?!\')\b'.format('|'.join(reserved)), Keyword.Reserved), + (r'(import|module)(\s+)', bygroups(Keyword.Reserved, Whitespace), 'module'), + (r"('')?[A-Z][\w\']*", Keyword.Type), + (r'[a-z][\w\']*', Text), + # Special Symbols + (r'(<-|::|->|=>|=)', Operator.Word), # specials + (r'([(){}\[\]:!#$%&*+.\\/<=>?@^|~-]+)', Operator.Word), # specials + # Numbers + (r'\d+[eE][+-]?\d+', Number.Float), + (r'\d+\.\d+([eE][+-]?\d+)?', Number.Float), + (r'0[xX][\da-fA-F]+', Number.Hex), + (r'\d+', Number.Integer), + # Strings + (r"'", String.Char, 'character'), + (r'"', String, 'string'), + (r'[^\s(){}]+', Text), + (r'\s+?', Whitespace), # Whitespace + ], + 'module': [ + (r'\s+', Whitespace), + (r'([A-Z][\w.]*)(\s+)(\()', + bygroups(Name.Namespace, Whitespace, Punctuation), 'funclist'), + (r'[A-Z][\w.]*', Name.Namespace, '#pop'), + ], + 'funclist': [ + (r'\s+', Whitespace), + (r'[A-Z]\w*', Keyword.Type), + (r'(_[\w\']+|[a-z][\w\']*)', Name.Function), + (r'--.*$', Comment.Single), + (r'\{-', Comment.Multiline, 'comment'), + (r',', Punctuation), + (r'[:!#$%&*+.\\/<=>?@^|~-]+', Operator), + # (HACK, but it makes sense to push two instances, believe me) + (r'\(', Punctuation, ('funclist', 'funclist')), + (r'\)', Punctuation, '#pop:2'), + ], + # NOTE: the next four states are shared in the AgdaLexer; make sure + # any change is compatible with Agda as well or copy over and change + 'comment': [ + # Multiline Comments + (r'[^-{}]+', Comment.Multiline), + (r'\{-', Comment.Multiline, '#push'), + (r'-\}', Comment.Multiline, '#pop'), + (r'[-{}]', Comment.Multiline), + ], + 'character': [ + # Allows multi-chars, incorrectly. + (r"[^\\']", String.Char), + (r"\\", String.Escape, 'escape'), + ("'", String.Char, '#pop'), + ], + 'string': [ + (r'[^\\"]+', String), + (r"\\", String.Escape, 'escape'), + ('"', String, '#pop'), + ], + 'escape': [ + (r'[abfnrtv"\'&\\]', String.Escape, '#pop'), + (r'\^[][A-Z@^_]', String.Escape, '#pop'), + ('|'.join(ascii), String.Escape, '#pop'), + (r'o[0-7]+', String.Escape, '#pop'), + (r'x[\da-fA-F]+', String.Escape, '#pop'), + (r'\d+', String.Escape, '#pop'), + (r'(\s+)(\\)', bygroups(Whitespace, String.Escape), '#pop') + ], + } + + +class AgdaLexer(RegexLexer): + """ + For the Agda dependently typed functional programming language and + proof assistant. + """ + + name = 'Agda' + url = 'http://wiki.portal.chalmers.se/agda/pmwiki.php' + aliases = ['agda'] + filenames = ['*.agda'] + mimetypes = ['text/x-agda'] + version_added = '2.0' + + reserved = ( + 'abstract', 'codata', 'coinductive', 'constructor', 'data', 'do', + 'eta-equality', 'field', 'forall', 'hiding', 'in', 'inductive', 'infix', + 'infixl', 'infixr', 'instance', 'interleaved', 'let', 'macro', 'mutual', + 'no-eta-equality', 'opaque', 'open', 'overlap', 'pattern', 'postulate', 'primitive', + 'private', 'quote', 'quoteTerm', 'record', 'renaming', 'rewrite', + 'syntax', 'tactic', 'unfolding', 'unquote', 'unquoteDecl', 'unquoteDef', 'using', + 'variable', 'where', 'with', + ) + + tokens = { + 'root': [ + # Declaration + (r'^(\s*)([^\s(){}]+)(\s*)(:)(\s*)', + bygroups(Whitespace, Name.Function, Whitespace, + Operator.Word, Whitespace)), + # Comments + (r'--(?![!#$%&*+./<=>?@^|_~:\\]).*?$', Comment.Single), + (r'\{-', Comment.Multiline, 'comment'), + # Holes + (r'\{!', Comment.Directive, 'hole'), + # Lexemes: + # Identifiers + (r'\b({})(?!\')\b'.format('|'.join(reserved)), Keyword.Reserved), + (r'(import|module)(\s+)', bygroups(Keyword.Reserved, Whitespace), + 'module'), + (r'\b(Set|Prop)[\u2080-\u2089]*\b', Keyword.Type), + # Special Symbols + (r'(\(|\)|\{|\})', Operator), + (r'(\.{1,3}|\||\u03BB|\u2200|\u2192|:|=|->)', Operator.Word), + # Numbers + (r'\d+[eE][+-]?\d+', Number.Float), + (r'\d+\.\d+([eE][+-]?\d+)?', Number.Float), + (r'0[xX][\da-fA-F]+', Number.Hex), + (r'\d+', Number.Integer), + # Strings + (r"'", String.Char, 'character'), + (r'"', String, 'string'), + (r'[^\s(){}]+', Text), + (r'\s+?', Whitespace), # Whitespace + ], + 'hole': [ + # Holes + (r'[^!{}]+', Comment.Directive), + (r'\{!', Comment.Directive, '#push'), + (r'!\}', Comment.Directive, '#pop'), + (r'[!{}]', Comment.Directive), + ], + 'module': [ + (r'\{-', Comment.Multiline, 'comment'), + (r'[a-zA-Z][\w.\']*', Name, '#pop'), + (r'[\W0-9_]+', Text) + ], + 'comment': HaskellLexer.tokens['comment'], + 'character': HaskellLexer.tokens['character'], + 'string': HaskellLexer.tokens['string'], + 'escape': HaskellLexer.tokens['escape'] + } + + +class CryptolLexer(RegexLexer): + """ + FIXME: A Cryptol2 lexer based on the lexemes defined in the Haskell 98 Report. + """ + name = 'Cryptol' + aliases = ['cryptol', 'cry'] + filenames = ['*.cry'] + mimetypes = ['text/x-cryptol'] + url = 'https://www.cryptol.net' + version_added = '2.0' + + reserved = ('Arith', 'Bit', 'Cmp', 'False', 'Inf', 'True', 'else', + 'export', 'extern', 'fin', 'if', 'import', 'inf', 'lg2', + 'max', 'min', 'module', 'newtype', 'pragma', 'property', + 'then', 'type', 'where', 'width') + ascii = ('NUL', 'SOH', '[SE]TX', 'EOT', 'ENQ', 'ACK', + 'BEL', 'BS', 'HT', 'LF', 'VT', 'FF', 'CR', 'S[OI]', 'DLE', + 'DC[1-4]', 'NAK', 'SYN', 'ETB', 'CAN', + 'EM', 'SUB', 'ESC', '[FGRU]S', 'SP', 'DEL') + + tokens = { + 'root': [ + # Whitespace: + (r'\s+', Whitespace), + # (r'--\s*|.*$', Comment.Doc), + (r'//.*$', Comment.Single), + (r'/\*', Comment.Multiline, 'comment'), + # Lexemes: + # Identifiers + (r'\bimport\b', Keyword.Reserved, 'import'), + (r'\bmodule\b', Keyword.Reserved, 'module'), + (r'\berror\b', Name.Exception), + (r'\b({})(?!\')\b'.format('|'.join(reserved)), Keyword.Reserved), + (r'^[_a-z][\w\']*', Name.Function), + (r"'?[_a-z][\w']*", Name), + (r"('')?[A-Z][\w\']*", Keyword.Type), + # Operators + (r'\\(?![:!#$%&*+.\\/<=>?@^|~-]+)', Name.Function), # lambda operator + (r'(<-|::|->|=>|=)(?![:!#$%&*+.\\/<=>?@^|~-]+)', Operator.Word), # specials + (r':[:!#$%&*+.\\/<=>?@^|~-]*', Keyword.Type), # Constructor operators + (r'[:!#$%&*+.\\/<=>?@^|~-]+', Operator), # Other operators + # Numbers + (r'\d+[eE][+-]?\d+', Number.Float), + (r'\d+\.\d+([eE][+-]?\d+)?', Number.Float), + (r'0[oO][0-7]+', Number.Oct), + (r'0[xX][\da-fA-F]+', Number.Hex), + (r'\d+', Number.Integer), + # Character/String Literals + (r"'", String.Char, 'character'), + (r'"', String, 'string'), + # Special + (r'\[\]', Keyword.Type), + (r'\(\)', Name.Builtin), + (r'[][(),;`{}]', Punctuation), + ], + 'import': [ + # Import statements + (r'\s+', Whitespace), + (r'"', String, 'string'), + # after "funclist" state + (r'\)', Punctuation, '#pop'), + (r'qualified\b', Keyword), + # import X as Y + (r'([A-Z][\w.]*)(\s+)(as)(\s+)([A-Z][\w.]*)', + bygroups(Name.Namespace, Whitespace, Keyword, Whitespace, Name), '#pop'), + # import X hiding (functions) + (r'([A-Z][\w.]*)(\s+)(hiding)(\s+)(\()', + bygroups(Name.Namespace, Whitespace, Keyword, Whitespace, Punctuation), 'funclist'), + # import X (functions) + (r'([A-Z][\w.]*)(\s+)(\()', + bygroups(Name.Namespace, Whitespace, Punctuation), 'funclist'), + # import X + (r'[\w.]+', Name.Namespace, '#pop'), + ], + 'module': [ + (r'\s+', Whitespace), + (r'([A-Z][\w.]*)(\s+)(\()', + bygroups(Name.Namespace, Whitespace, Punctuation), 'funclist'), + (r'[A-Z][\w.]*', Name.Namespace, '#pop'), + ], + 'funclist': [ + (r'\s+', Whitespace), + (r'[A-Z]\w*', Keyword.Type), + (r'(_[\w\']+|[a-z][\w\']*)', Name.Function), + # TODO: these don't match the comments in docs, remove. + # (r'--(?![!#$%&*+./<=>?@^|_~:\\]).*?$', Comment.Single), + # (r'{-', Comment.Multiline, 'comment'), + (r',', Punctuation), + (r'[:!#$%&*+.\\/<=>?@^|~-]+', Operator), + # (HACK, but it makes sense to push two instances, believe me) + (r'\(', Punctuation, ('funclist', 'funclist')), + (r'\)', Punctuation, '#pop:2'), + ], + 'comment': [ + # Multiline Comments + (r'[^/*]+', Comment.Multiline), + (r'/\*', Comment.Multiline, '#push'), + (r'\*/', Comment.Multiline, '#pop'), + (r'[*/]', Comment.Multiline), + ], + 'character': [ + # Allows multi-chars, incorrectly. + (r"[^\\']'", String.Char, '#pop'), + (r"\\", String.Escape, 'escape'), + ("'", String.Char, '#pop'), + ], + 'string': [ + (r'[^\\"]+', String), + (r"\\", String.Escape, 'escape'), + ('"', String, '#pop'), + ], + 'escape': [ + (r'[abfnrtv"\'&\\]', String.Escape, '#pop'), + (r'\^[][A-Z@^_]', String.Escape, '#pop'), + ('|'.join(ascii), String.Escape, '#pop'), + (r'o[0-7]+', String.Escape, '#pop'), + (r'x[\da-fA-F]+', String.Escape, '#pop'), + (r'\d+', String.Escape, '#pop'), + (r'(\s+)(\\)', bygroups(Whitespace, String.Escape), '#pop'), + ], + } + + EXTRA_KEYWORDS = {'join', 'split', 'reverse', 'transpose', 'width', + 'length', 'tail', '<<', '>>', '<<<', '>>>', 'const', + 'reg', 'par', 'seq', 'ASSERT', 'undefined', 'error', + 'trace'} + + def get_tokens_unprocessed(self, text): + stack = ['root'] + for index, token, value in \ + RegexLexer.get_tokens_unprocessed(self, text, stack): + if token is Name and value in self.EXTRA_KEYWORDS: + yield index, Name.Builtin, value + else: + yield index, token, value + + +class LiterateLexer(Lexer): + """ + Base class for lexers of literate file formats based on LaTeX or Bird-style + (prefixing each code line with ">"). + + Additional options accepted: + + `litstyle` + If given, must be ``"bird"`` or ``"latex"``. If not given, the style + is autodetected: if the first non-whitespace character in the source + is a backslash or percent character, LaTeX is assumed, else Bird. + """ + + bird_re = re.compile(r'(>[ \t]*)(.*\n)') + + def __init__(self, baselexer, **options): + self.baselexer = baselexer + Lexer.__init__(self, **options) + + def get_tokens_unprocessed(self, text): + style = self.options.get('litstyle') + if style is None: + style = (text.lstrip()[0:1] in '%\\') and 'latex' or 'bird' + + code = '' + insertions = [] + if style == 'bird': + # bird-style + for match in line_re.finditer(text): + line = match.group() + m = self.bird_re.match(line) + if m: + insertions.append((len(code), + [(0, Comment.Special, m.group(1))])) + code += m.group(2) + else: + insertions.append((len(code), [(0, Text, line)])) + else: + # latex-style + from pygments.lexers.markup import TexLexer + lxlexer = TexLexer(**self.options) + codelines = 0 + latex = '' + for match in line_re.finditer(text): + line = match.group() + if codelines: + if line.lstrip().startswith('\\end{code}'): + codelines = 0 + latex += line + else: + code += line + elif line.lstrip().startswith('\\begin{code}'): + codelines = 1 + latex += line + insertions.append((len(code), + list(lxlexer.get_tokens_unprocessed(latex)))) + latex = '' + else: + latex += line + insertions.append((len(code), + list(lxlexer.get_tokens_unprocessed(latex)))) + yield from do_insertions(insertions, self.baselexer.get_tokens_unprocessed(code)) + + +class LiterateHaskellLexer(LiterateLexer): + """ + For Literate Haskell (Bird-style or LaTeX) source. + + Additional options accepted: + + `litstyle` + If given, must be ``"bird"`` or ``"latex"``. If not given, the style + is autodetected: if the first non-whitespace character in the source + is a backslash or percent character, LaTeX is assumed, else Bird. + """ + name = 'Literate Haskell' + aliases = ['literate-haskell', 'lhaskell', 'lhs'] + filenames = ['*.lhs'] + mimetypes = ['text/x-literate-haskell'] + url = 'https://wiki.haskell.org/Literate_programming' + version_added = '0.9' + + def __init__(self, **options): + hslexer = HaskellLexer(**options) + LiterateLexer.__init__(self, hslexer, **options) + + +class LiterateIdrisLexer(LiterateLexer): + """ + For Literate Idris (Bird-style or LaTeX) source. + + Additional options accepted: + + `litstyle` + If given, must be ``"bird"`` or ``"latex"``. If not given, the style + is autodetected: if the first non-whitespace character in the source + is a backslash or percent character, LaTeX is assumed, else Bird. + """ + name = 'Literate Idris' + aliases = ['literate-idris', 'lidris', 'lidr'] + filenames = ['*.lidr'] + mimetypes = ['text/x-literate-idris'] + url = 'https://idris2.readthedocs.io/en/latest/reference/literate.html' + version_added = '2.0' + + def __init__(self, **options): + hslexer = IdrisLexer(**options) + LiterateLexer.__init__(self, hslexer, **options) + + +class LiterateAgdaLexer(LiterateLexer): + """ + For Literate Agda source. + + Additional options accepted: + + `litstyle` + If given, must be ``"bird"`` or ``"latex"``. If not given, the style + is autodetected: if the first non-whitespace character in the source + is a backslash or percent character, LaTeX is assumed, else Bird. + """ + name = 'Literate Agda' + aliases = ['literate-agda', 'lagda'] + filenames = ['*.lagda'] + mimetypes = ['text/x-literate-agda'] + url = 'https://agda.readthedocs.io/en/latest/tools/literate-programming.html' + version_added = '2.0' + + def __init__(self, **options): + agdalexer = AgdaLexer(**options) + LiterateLexer.__init__(self, agdalexer, litstyle='latex', **options) + + +class LiterateCryptolLexer(LiterateLexer): + """ + For Literate Cryptol (Bird-style or LaTeX) source. + + Additional options accepted: + + `litstyle` + If given, must be ``"bird"`` or ``"latex"``. If not given, the style + is autodetected: if the first non-whitespace character in the source + is a backslash or percent character, LaTeX is assumed, else Bird. + """ + name = 'Literate Cryptol' + aliases = ['literate-cryptol', 'lcryptol', 'lcry'] + filenames = ['*.lcry'] + mimetypes = ['text/x-literate-cryptol'] + url = 'https://www.cryptol.net' + version_added = '2.0' + + def __init__(self, **options): + crylexer = CryptolLexer(**options) + LiterateLexer.__init__(self, crylexer, **options) + + +class KokaLexer(RegexLexer): + """ + Lexer for the Koka language. + """ + + name = 'Koka' + url = 'https://koka-lang.github.io/koka/doc/index.html' + aliases = ['koka'] + filenames = ['*.kk', '*.kki'] + mimetypes = ['text/x-koka'] + version_added = '1.6' + + keywords = [ + 'infix', 'infixr', 'infixl', + 'type', 'cotype', 'rectype', 'alias', + 'struct', 'con', + 'fun', 'function', 'val', 'var', + 'external', + 'if', 'then', 'else', 'elif', 'return', 'match', + 'private', 'public', 'private', + 'module', 'import', 'as', + 'include', 'inline', + 'rec', + 'try', 'yield', 'enum', + 'interface', 'instance', + ] + + # keywords that are followed by a type + typeStartKeywords = [ + 'type', 'cotype', 'rectype', 'alias', 'struct', 'enum', + ] + + # keywords valid in a type + typekeywords = [ + 'forall', 'exists', 'some', 'with', + ] + + # builtin names and special names + builtin = [ + 'for', 'while', 'repeat', + 'foreach', 'foreach-indexed', + 'error', 'catch', 'finally', + 'cs', 'js', 'file', 'ref', 'assigned', + ] + + # symbols that can be in an operator + symbols = r'[$%&*+@!/\\^~=.:\-?|<>]+' + + # symbol boundary: an operator keyword should not be followed by any of these + sboundary = '(?!' + symbols + ')' + + # name boundary: a keyword should not be followed by any of these + boundary = r'(?![\w/])' + + # koka token abstractions + tokenType = Name.Attribute + tokenTypeDef = Name.Class + tokenConstructor = Generic.Emph + + # main lexer + tokens = { + 'root': [ + include('whitespace'), + + # go into type mode + (r'::?' + sboundary, tokenType, 'type'), + (r'(alias)(\s+)([a-z]\w*)?', bygroups(Keyword, Whitespace, tokenTypeDef), + 'alias-type'), + (r'(struct)(\s+)([a-z]\w*)?', bygroups(Keyword, Whitespace, tokenTypeDef), + 'struct-type'), + ((r'({})'.format('|'.join(typeStartKeywords))) + + r'(\s+)([a-z]\w*)?', bygroups(Keyword, Whitespace, tokenTypeDef), + 'type'), + + # special sequences of tokens (we use ?: for non-capturing group as + # required by 'bygroups') + (r'(module)(\s+)(interface(?=\s))?(\s+)?((?:[a-z]\w*/)*[a-z]\w*)', + bygroups(Keyword, Whitespace, Keyword, Whitespace, Name.Namespace)), + (r'(import)(\s+)((?:[a-z]\w*/)*[a-z]\w*)' + r'(?:(\s*)(=)(\s*)(qualified)?(\s*)' + r'((?:[a-z]\w*/)*[a-z]\w*))?', + bygroups(Keyword, Whitespace, Name.Namespace, Whitespace, Keyword, Whitespace, + Keyword, Whitespace, Name.Namespace)), + + (r'^(public|private)?(\s+)?(function|fun|val)' + r'(\s+)([a-z]\w*|\((?:' + symbols + r'|/)\))', + bygroups(Keyword, Whitespace, Keyword, Whitespace, Name.Function)), + (r'^(?:(public|private)(?=\s+external))?((?|[=.]' + sboundary, Keyword), + + # names + (r'((?:[a-z]\w*/)*)([A-Z]\w*)', + bygroups(Name.Namespace, tokenConstructor)), + (r'((?:[a-z]\w*/)*)([a-z]\w*)', bygroups(Name.Namespace, Name)), + (r'((?:[a-z]\w*/)*)(\((?:' + symbols + r'|/)\))', + bygroups(Name.Namespace, Name)), + (r'_\w*', Name.Variable), + + # literal string + (r'@"', String.Double, 'litstring'), + + # operators + (symbols + "|/(?![*/])", Operator), + (r'`', Operator), + (r'[{}()\[\];,]', Punctuation), + + # literals. No check for literal characters with len > 1 + (r'[0-9]+\.[0-9]+([eE][\-+]?[0-9]+)?', Number.Float), + (r'0[xX][0-9a-fA-F]+', Number.Hex), + (r'[0-9]+', Number.Integer), + + (r"'", String.Char, 'char'), + (r'"', String.Double, 'string'), + ], + + # type started by alias + 'alias-type': [ + (r'=', Keyword), + include('type') + ], + + # type started by struct + 'struct-type': [ + (r'(?=\((?!,*\)))', Punctuation, '#pop'), + include('type') + ], + + # type started by colon + 'type': [ + (r'[(\[<]', tokenType, 'type-nested'), + include('type-content') + ], + + # type nested in brackets: can contain parameters, comma etc. + 'type-nested': [ + (r'[)\]>]', tokenType, '#pop'), + (r'[(\[<]', tokenType, 'type-nested'), + (r',', tokenType), + (r'([a-z]\w*)(\s*)(:)(?!:)', + bygroups(Name, Whitespace, tokenType)), # parameter name + include('type-content') + ], + + # shared contents of a type + 'type-content': [ + include('whitespace'), + + # keywords + (r'({})'.format('|'.join(typekeywords)) + boundary, Keyword), + (r'(?=(({})'.format('|'.join(keywords)) + boundary + '))', + Keyword, '#pop'), # need to match because names overlap... + + # kinds + (r'[EPHVX]' + boundary, tokenType), + + # type names + (r'[a-z][0-9]*(?![\w/])', tokenType), + (r'_\w*', tokenType.Variable), # Generic.Emph + (r'((?:[a-z]\w*/)*)([A-Z]\w*)', + bygroups(Name.Namespace, tokenType)), + (r'((?:[a-z]\w*/)*)([a-z]\w+)', + bygroups(Name.Namespace, tokenType)), + + # type keyword operators + (r'::|->|[.:|]', tokenType), + + # catchall + default('#pop') + ], + + # comments and literals + 'whitespace': [ + (r'(\n\s*)(#.*)$', bygroups(Whitespace, Comment.Preproc)), + (r'\s+', Whitespace), + (r'/\*', Comment.Multiline, 'comment'), + (r'//.*$', Comment.Single) + ], + 'comment': [ + (r'[^/*]+', Comment.Multiline), + (r'/\*', Comment.Multiline, '#push'), + (r'\*/', Comment.Multiline, '#pop'), + (r'[*/]', Comment.Multiline), + ], + 'litstring': [ + (r'[^"]+', String.Double), + (r'""', String.Escape), + (r'"', String.Double, '#pop'), + ], + 'string': [ + (r'[^\\"\n]+', String.Double), + include('escape-sequence'), + (r'["\n]', String.Double, '#pop'), + ], + 'char': [ + (r'[^\\\'\n]+', String.Char), + include('escape-sequence'), + (r'[\'\n]', String.Char, '#pop'), + ], + 'escape-sequence': [ + (r'\\[nrt\\"\']', String.Escape), + (r'\\x[0-9a-fA-F]{2}', String.Escape), + (r'\\u[0-9a-fA-F]{4}', String.Escape), + # Yes, \U literals are 6 hex digits. + (r'\\U[0-9a-fA-F]{6}', String.Escape) + ] + } diff --git a/pygments/lexers/haxe.py b/pygments/lexers/haxe.py new file mode 100644 index 0000000000000000000000000000000000000000..66bd2a350f327ef59b79048b31630c1373f73375 --- /dev/null +++ b/pygments/lexers/haxe.py @@ -0,0 +1,935 @@ +""" + pygments.lexers.haxe + ~~~~~~~~~~~~~~~~~~~~ + + Lexers for Haxe and related stuff. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +import re + +from pygments.lexer import ExtendedRegexLexer, RegexLexer, include, bygroups, \ + default +from pygments.token import Text, Comment, Operator, Keyword, Name, String, \ + Number, Punctuation, Generic, Whitespace + +__all__ = ['HaxeLexer', 'HxmlLexer'] + + +class HaxeLexer(ExtendedRegexLexer): + """ + For Haxe source code. + """ + + name = 'Haxe' + url = 'http://haxe.org/' + aliases = ['haxe', 'hxsl', 'hx'] + filenames = ['*.hx', '*.hxsl'] + mimetypes = ['text/haxe', 'text/x-haxe', 'text/x-hx'] + version_added = '1.3' + + # keywords extracted from lexer.mll in the haxe compiler source + keyword = (r'(?:function|class|static|var|if|else|while|do|for|' + r'break|return|continue|extends|implements|import|' + r'switch|case|default|public|private|try|untyped|' + r'catch|new|this|throw|extern|enum|in|interface|' + r'cast|override|dynamic|typedef|package|' + r'inline|using|null|true|false|abstract)\b') + + # idtype in lexer.mll + typeid = r'_*[A-Z]\w*' + + # combined ident and dollar and idtype + ident = r'(?:_*[a-z]\w*|_+[0-9]\w*|' + typeid + r'|_+|\$\w+)' + + binop = (r'(?:%=|&=|\|=|\^=|\+=|\-=|\*=|/=|<<=|>\s*>\s*=|>\s*>\s*>\s*=|==|' + r'!=|<=|>\s*=|&&|\|\||<<|>>>|>\s*>|\.\.\.|<|>|%|&|\||\^|\+|\*|' + r'/|\-|=>|=)') + + # ident except keywords + ident_no_keyword = r'(?!' + keyword + ')' + ident + + flags = re.DOTALL | re.MULTILINE + + preproc_stack = [] + + def preproc_callback(self, match, ctx): + proc = match.group(2) + + if proc == 'if': + # store the current stack + self.preproc_stack.append(ctx.stack[:]) + elif proc in ['else', 'elseif']: + # restore the stack back to right before #if + if self.preproc_stack: + ctx.stack = self.preproc_stack[-1][:] + elif proc == 'end': + # remove the saved stack of previous #if + if self.preproc_stack: + self.preproc_stack.pop() + + # #if and #elseif should follow by an expr + if proc in ['if', 'elseif']: + ctx.stack.append('preproc-expr') + + # #error can be optionally follow by the error msg + if proc in ['error']: + ctx.stack.append('preproc-error') + + yield match.start(), Comment.Preproc, '#' + proc + ctx.pos = match.end() + + tokens = { + 'root': [ + include('spaces'), + include('meta'), + (r'(?:package)\b', Keyword.Namespace, ('semicolon', 'package')), + (r'(?:import)\b', Keyword.Namespace, ('semicolon', 'import')), + (r'(?:using)\b', Keyword.Namespace, ('semicolon', 'using')), + (r'(?:extern|private)\b', Keyword.Declaration), + (r'(?:abstract)\b', Keyword.Declaration, 'abstract'), + (r'(?:class|interface)\b', Keyword.Declaration, 'class'), + (r'(?:enum)\b', Keyword.Declaration, 'enum'), + (r'(?:typedef)\b', Keyword.Declaration, 'typedef'), + + # top-level expression + # although it is not supported in haxe, but it is common to write + # expression in web pages the positive lookahead here is to prevent + # an infinite loop at the EOF + (r'(?=.)', Text, 'expr-statement'), + ], + + # space/tab/comment/preproc + 'spaces': [ + (r'\s+', Whitespace), + (r'//[^\n\r]*', Comment.Single), + (r'/\*.*?\*/', Comment.Multiline), + (r'(#)(if|elseif|else|end|error)\b', preproc_callback), + ], + + 'string-single-interpol': [ + (r'\$\{', String.Interpol, ('string-interpol-close', 'expr')), + (r'\$\$', String.Escape), + (r'\$(?=' + ident + ')', String.Interpol, 'ident'), + include('string-single'), + ], + + 'string-single': [ + (r"'", String.Single, '#pop'), + (r'\\.', String.Escape), + (r'.', String.Single), + ], + + 'string-double': [ + (r'"', String.Double, '#pop'), + (r'\\.', String.Escape), + (r'.', String.Double), + ], + + 'string-interpol-close': [ + (r'\$'+ident, String.Interpol), + (r'\}', String.Interpol, '#pop'), + ], + + 'package': [ + include('spaces'), + (ident, Name.Namespace), + (r'\.', Punctuation, 'import-ident'), + default('#pop'), + ], + + 'import': [ + include('spaces'), + (ident, Name.Namespace), + (r'\*', Keyword), # wildcard import + (r'\.', Punctuation, 'import-ident'), + (r'in', Keyword.Namespace, 'ident'), + default('#pop'), + ], + + 'import-ident': [ + include('spaces'), + (r'\*', Keyword, '#pop'), # wildcard import + (ident, Name.Namespace, '#pop'), + ], + + 'using': [ + include('spaces'), + (ident, Name.Namespace), + (r'\.', Punctuation, 'import-ident'), + default('#pop'), + ], + + 'preproc-error': [ + (r'\s+', Whitespace), + (r"'", String.Single, ('#pop', 'string-single')), + (r'"', String.Double, ('#pop', 'string-double')), + default('#pop'), + ], + + 'preproc-expr': [ + (r'\s+', Whitespace), + (r'\!', Comment.Preproc), + (r'\(', Comment.Preproc, ('#pop', 'preproc-parenthesis')), + + (ident, Comment.Preproc, '#pop'), + + # Float + (r'\.[0-9]+', Number.Float), + (r'[0-9]+[eE][+\-]?[0-9]+', Number.Float), + (r'[0-9]+\.[0-9]*[eE][+\-]?[0-9]+', Number.Float), + (r'[0-9]+\.[0-9]+', Number.Float), + (r'[0-9]+\.(?!' + ident + r'|\.\.)', Number.Float), + + # Int + (r'0x[0-9a-fA-F]+', Number.Hex), + (r'[0-9]+', Number.Integer), + + # String + (r"'", String.Single, ('#pop', 'string-single')), + (r'"', String.Double, ('#pop', 'string-double')), + ], + + 'preproc-parenthesis': [ + (r'\s+', Whitespace), + (r'\)', Comment.Preproc, '#pop'), + default('preproc-expr-in-parenthesis'), + ], + + 'preproc-expr-chain': [ + (r'\s+', Whitespace), + (binop, Comment.Preproc, ('#pop', 'preproc-expr-in-parenthesis')), + default('#pop'), + ], + + # same as 'preproc-expr' but able to chain 'preproc-expr-chain' + 'preproc-expr-in-parenthesis': [ + (r'\s+', Whitespace), + (r'\!', Comment.Preproc), + (r'\(', Comment.Preproc, + ('#pop', 'preproc-expr-chain', 'preproc-parenthesis')), + + (ident, Comment.Preproc, ('#pop', 'preproc-expr-chain')), + + # Float + (r'\.[0-9]+', Number.Float, ('#pop', 'preproc-expr-chain')), + (r'[0-9]+[eE][+\-]?[0-9]+', Number.Float, ('#pop', 'preproc-expr-chain')), + (r'[0-9]+\.[0-9]*[eE][+\-]?[0-9]+', Number.Float, ('#pop', 'preproc-expr-chain')), + (r'[0-9]+\.[0-9]+', Number.Float, ('#pop', 'preproc-expr-chain')), + (r'[0-9]+\.(?!' + ident + r'|\.\.)', Number.Float, ('#pop', 'preproc-expr-chain')), + + # Int + (r'0x[0-9a-fA-F]+', Number.Hex, ('#pop', 'preproc-expr-chain')), + (r'[0-9]+', Number.Integer, ('#pop', 'preproc-expr-chain')), + + # String + (r"'", String.Single, + ('#pop', 'preproc-expr-chain', 'string-single')), + (r'"', String.Double, + ('#pop', 'preproc-expr-chain', 'string-double')), + ], + + 'abstract': [ + include('spaces'), + default(('#pop', 'abstract-body', 'abstract-relation', + 'abstract-opaque', 'type-param-constraint', 'type-name')), + ], + + 'abstract-body': [ + include('spaces'), + (r'\{', Punctuation, ('#pop', 'class-body')), + ], + + 'abstract-opaque': [ + include('spaces'), + (r'\(', Punctuation, ('#pop', 'parenthesis-close', 'type')), + default('#pop'), + ], + + 'abstract-relation': [ + include('spaces'), + (r'(?:to|from)', Keyword.Declaration, 'type'), + (r',', Punctuation), + default('#pop'), + ], + + 'meta': [ + include('spaces'), + (r'@', Name.Decorator, ('meta-body', 'meta-ident', 'meta-colon')), + ], + + # optional colon + 'meta-colon': [ + include('spaces'), + (r':', Name.Decorator, '#pop'), + default('#pop'), + ], + + # same as 'ident' but set token as Name.Decorator instead of Name + 'meta-ident': [ + include('spaces'), + (ident, Name.Decorator, '#pop'), + ], + + 'meta-body': [ + include('spaces'), + (r'\(', Name.Decorator, ('#pop', 'meta-call')), + default('#pop'), + ], + + 'meta-call': [ + include('spaces'), + (r'\)', Name.Decorator, '#pop'), + default(('#pop', 'meta-call-sep', 'expr')), + ], + + 'meta-call-sep': [ + include('spaces'), + (r'\)', Name.Decorator, '#pop'), + (r',', Punctuation, ('#pop', 'meta-call')), + ], + + 'typedef': [ + include('spaces'), + default(('#pop', 'typedef-body', 'type-param-constraint', + 'type-name')), + ], + + 'typedef-body': [ + include('spaces'), + (r'=', Operator, ('#pop', 'optional-semicolon', 'type')), + ], + + 'enum': [ + include('spaces'), + default(('#pop', 'enum-body', 'bracket-open', + 'type-param-constraint', 'type-name')), + ], + + 'enum-body': [ + include('spaces'), + include('meta'), + (r'\}', Punctuation, '#pop'), + (ident_no_keyword, Name, ('enum-member', 'type-param-constraint')), + ], + + 'enum-member': [ + include('spaces'), + (r'\(', Punctuation, + ('#pop', 'semicolon', 'flag', 'function-param')), + default(('#pop', 'semicolon', 'flag')), + ], + + 'class': [ + include('spaces'), + default(('#pop', 'class-body', 'bracket-open', 'extends', + 'type-param-constraint', 'type-name')), + ], + + 'extends': [ + include('spaces'), + (r'(?:extends|implements)\b', Keyword.Declaration, 'type'), + (r',', Punctuation), # the comma is made optional here, since haxe2 + # requires the comma but haxe3 does not allow it + default('#pop'), + ], + + 'bracket-open': [ + include('spaces'), + (r'\{', Punctuation, '#pop'), + ], + + 'bracket-close': [ + include('spaces'), + (r'\}', Punctuation, '#pop'), + ], + + 'class-body': [ + include('spaces'), + include('meta'), + (r'\}', Punctuation, '#pop'), + (r'(?:static|public|private|override|dynamic|inline|macro)\b', + Keyword.Declaration), + default('class-member'), + ], + + 'class-member': [ + include('spaces'), + (r'(var)\b', Keyword.Declaration, + ('#pop', 'optional-semicolon', 'var')), + (r'(function)\b', Keyword.Declaration, + ('#pop', 'optional-semicolon', 'class-method')), + ], + + # local function, anonymous or not + 'function-local': [ + include('spaces'), + (ident_no_keyword, Name.Function, + ('#pop', 'optional-expr', 'flag', 'function-param', + 'parenthesis-open', 'type-param-constraint')), + default(('#pop', 'optional-expr', 'flag', 'function-param', + 'parenthesis-open', 'type-param-constraint')), + ], + + 'optional-expr': [ + include('spaces'), + include('expr'), + default('#pop'), + ], + + 'class-method': [ + include('spaces'), + (ident, Name.Function, ('#pop', 'optional-expr', 'flag', + 'function-param', 'parenthesis-open', + 'type-param-constraint')), + ], + + # function arguments + 'function-param': [ + include('spaces'), + (r'\)', Punctuation, '#pop'), + (r'\?', Punctuation), + (ident_no_keyword, Name, + ('#pop', 'function-param-sep', 'assign', 'flag')), + ], + + 'function-param-sep': [ + include('spaces'), + (r'\)', Punctuation, '#pop'), + (r',', Punctuation, ('#pop', 'function-param')), + ], + + 'prop-get-set': [ + include('spaces'), + (r'\(', Punctuation, ('#pop', 'parenthesis-close', + 'prop-get-set-opt', 'comma', 'prop-get-set-opt')), + default('#pop'), + ], + + 'prop-get-set-opt': [ + include('spaces'), + (r'(?:default|null|never|dynamic|get|set)\b', Keyword, '#pop'), + (ident_no_keyword, Text, '#pop'), # custom getter/setter + ], + + 'expr-statement': [ + include('spaces'), + # makes semicolon optional here, just to avoid checking the last + # one is bracket or not. + default(('#pop', 'optional-semicolon', 'expr')), + ], + + 'expr': [ + include('spaces'), + (r'@', Name.Decorator, ('#pop', 'optional-expr', 'meta-body', + 'meta-ident', 'meta-colon')), + (r'(?:\+\+|\-\-|~(?!/)|!|\-)', Operator), + (r'\(', Punctuation, ('#pop', 'expr-chain', 'parenthesis')), + (r'(?:static|public|private|override|dynamic|inline)\b', + Keyword.Declaration), + (r'(?:function)\b', Keyword.Declaration, ('#pop', 'expr-chain', + 'function-local')), + (r'\{', Punctuation, ('#pop', 'expr-chain', 'bracket')), + (r'(?:true|false|null)\b', Keyword.Constant, ('#pop', 'expr-chain')), + (r'(?:this)\b', Keyword, ('#pop', 'expr-chain')), + (r'(?:cast)\b', Keyword, ('#pop', 'expr-chain', 'cast')), + (r'(?:try)\b', Keyword, ('#pop', 'catch', 'expr')), + (r'(?:var)\b', Keyword.Declaration, ('#pop', 'var')), + (r'(?:new)\b', Keyword, ('#pop', 'expr-chain', 'new')), + (r'(?:switch)\b', Keyword, ('#pop', 'switch')), + (r'(?:if)\b', Keyword, ('#pop', 'if')), + (r'(?:do)\b', Keyword, ('#pop', 'do')), + (r'(?:while)\b', Keyword, ('#pop', 'while')), + (r'(?:for)\b', Keyword, ('#pop', 'for')), + (r'(?:untyped|throw)\b', Keyword), + (r'(?:return)\b', Keyword, ('#pop', 'optional-expr')), + (r'(?:macro)\b', Keyword, ('#pop', 'macro')), + (r'(?:continue|break)\b', Keyword, '#pop'), + (r'(?:\$\s*[a-z]\b|\$(?!'+ident+'))', Name, ('#pop', 'dollar')), + (ident_no_keyword, Name, ('#pop', 'expr-chain')), + + # Float + (r'\.[0-9]+', Number.Float, ('#pop', 'expr-chain')), + (r'[0-9]+[eE][+\-]?[0-9]+', Number.Float, ('#pop', 'expr-chain')), + (r'[0-9]+\.[0-9]*[eE][+\-]?[0-9]+', Number.Float, ('#pop', 'expr-chain')), + (r'[0-9]+\.[0-9]+', Number.Float, ('#pop', 'expr-chain')), + (r'[0-9]+\.(?!' + ident + r'|\.\.)', Number.Float, ('#pop', 'expr-chain')), + + # Int + (r'0x[0-9a-fA-F]+', Number.Hex, ('#pop', 'expr-chain')), + (r'[0-9]+', Number.Integer, ('#pop', 'expr-chain')), + + # String + (r"'", String.Single, ('#pop', 'expr-chain', 'string-single-interpol')), + (r'"', String.Double, ('#pop', 'expr-chain', 'string-double')), + + # EReg + (r'~/(\\\\|\\[^\\]|[^/\\\n])*/[gimsu]*', String.Regex, ('#pop', 'expr-chain')), + + # Array + (r'\[', Punctuation, ('#pop', 'expr-chain', 'array-decl')), + ], + + 'expr-chain': [ + include('spaces'), + (r'(?:\+\+|\-\-)', Operator), + (binop, Operator, ('#pop', 'expr')), + (r'(?:in)\b', Keyword, ('#pop', 'expr')), + (r'\?', Operator, ('#pop', 'expr', 'ternary', 'expr')), + (r'(\.)(' + ident_no_keyword + ')', bygroups(Punctuation, Name)), + (r'\[', Punctuation, 'array-access'), + (r'\(', Punctuation, 'call'), + default('#pop'), + ], + + # macro reification + 'macro': [ + include('spaces'), + include('meta'), + (r':', Punctuation, ('#pop', 'type')), + + (r'(?:extern|private)\b', Keyword.Declaration), + (r'(?:abstract)\b', Keyword.Declaration, ('#pop', 'optional-semicolon', 'abstract')), + (r'(?:class|interface)\b', Keyword.Declaration, ('#pop', 'optional-semicolon', 'macro-class')), + (r'(?:enum)\b', Keyword.Declaration, ('#pop', 'optional-semicolon', 'enum')), + (r'(?:typedef)\b', Keyword.Declaration, ('#pop', 'optional-semicolon', 'typedef')), + + default(('#pop', 'expr')), + ], + + 'macro-class': [ + (r'\{', Punctuation, ('#pop', 'class-body')), + include('class') + ], + + # cast can be written as "cast expr" or "cast(expr, type)" + 'cast': [ + include('spaces'), + (r'\(', Punctuation, ('#pop', 'parenthesis-close', + 'cast-type', 'expr')), + default(('#pop', 'expr')), + ], + + # optionally give a type as the 2nd argument of cast() + 'cast-type': [ + include('spaces'), + (r',', Punctuation, ('#pop', 'type')), + default('#pop'), + ], + + 'catch': [ + include('spaces'), + (r'(?:catch)\b', Keyword, ('expr', 'function-param', + 'parenthesis-open')), + default('#pop'), + ], + + # do-while loop + 'do': [ + include('spaces'), + default(('#pop', 'do-while', 'expr')), + ], + + # the while after do + 'do-while': [ + include('spaces'), + (r'(?:while)\b', Keyword, ('#pop', 'parenthesis', + 'parenthesis-open')), + ], + + 'while': [ + include('spaces'), + (r'\(', Punctuation, ('#pop', 'expr', 'parenthesis')), + ], + + 'for': [ + include('spaces'), + (r'\(', Punctuation, ('#pop', 'expr', 'parenthesis')), + ], + + 'if': [ + include('spaces'), + (r'\(', Punctuation, ('#pop', 'else', 'optional-semicolon', 'expr', + 'parenthesis')), + ], + + 'else': [ + include('spaces'), + (r'(?:else)\b', Keyword, ('#pop', 'expr')), + default('#pop'), + ], + + 'switch': [ + include('spaces'), + default(('#pop', 'switch-body', 'bracket-open', 'expr')), + ], + + 'switch-body': [ + include('spaces'), + (r'(?:case|default)\b', Keyword, ('case-block', 'case')), + (r'\}', Punctuation, '#pop'), + ], + + 'case': [ + include('spaces'), + (r':', Punctuation, '#pop'), + default(('#pop', 'case-sep', 'case-guard', 'expr')), + ], + + 'case-sep': [ + include('spaces'), + (r':', Punctuation, '#pop'), + (r',', Punctuation, ('#pop', 'case')), + ], + + 'case-guard': [ + include('spaces'), + (r'(?:if)\b', Keyword, ('#pop', 'parenthesis', 'parenthesis-open')), + default('#pop'), + ], + + # optional multiple expr under a case + 'case-block': [ + include('spaces'), + (r'(?!(?:case|default)\b|\})', Keyword, 'expr-statement'), + default('#pop'), + ], + + 'new': [ + include('spaces'), + default(('#pop', 'call', 'parenthesis-open', 'type')), + ], + + 'array-decl': [ + include('spaces'), + (r'\]', Punctuation, '#pop'), + default(('#pop', 'array-decl-sep', 'expr')), + ], + + 'array-decl-sep': [ + include('spaces'), + (r'\]', Punctuation, '#pop'), + (r',', Punctuation, ('#pop', 'array-decl')), + ], + + 'array-access': [ + include('spaces'), + default(('#pop', 'array-access-close', 'expr')), + ], + + 'array-access-close': [ + include('spaces'), + (r'\]', Punctuation, '#pop'), + ], + + 'comma': [ + include('spaces'), + (r',', Punctuation, '#pop'), + ], + + 'colon': [ + include('spaces'), + (r':', Punctuation, '#pop'), + ], + + 'semicolon': [ + include('spaces'), + (r';', Punctuation, '#pop'), + ], + + 'optional-semicolon': [ + include('spaces'), + (r';', Punctuation, '#pop'), + default('#pop'), + ], + + # identity that CAN be a Haxe keyword + 'ident': [ + include('spaces'), + (ident, Name, '#pop'), + ], + + 'dollar': [ + include('spaces'), + (r'\{', Punctuation, ('#pop', 'expr-chain', 'bracket-close', 'expr')), + default(('#pop', 'expr-chain')), + ], + + 'type-name': [ + include('spaces'), + (typeid, Name, '#pop'), + ], + + 'type-full-name': [ + include('spaces'), + (r'\.', Punctuation, 'ident'), + default('#pop'), + ], + + 'type': [ + include('spaces'), + (r'\?', Punctuation), + (ident, Name, ('#pop', 'type-check', 'type-full-name')), + (r'\{', Punctuation, ('#pop', 'type-check', 'type-struct')), + (r'\(', Punctuation, ('#pop', 'type-check', 'type-parenthesis')), + ], + + 'type-parenthesis': [ + include('spaces'), + default(('#pop', 'parenthesis-close', 'type')), + ], + + 'type-check': [ + include('spaces'), + (r'->', Punctuation, ('#pop', 'type')), + (r'<(?!=)', Punctuation, 'type-param'), + default('#pop'), + ], + + 'type-struct': [ + include('spaces'), + (r'\}', Punctuation, '#pop'), + (r'\?', Punctuation), + (r'>', Punctuation, ('comma', 'type')), + (ident_no_keyword, Name, ('#pop', 'type-struct-sep', 'type', 'colon')), + include('class-body'), + ], + + 'type-struct-sep': [ + include('spaces'), + (r'\}', Punctuation, '#pop'), + (r',', Punctuation, ('#pop', 'type-struct')), + ], + + # type-param can be a normal type or a constant literal... + 'type-param-type': [ + # Float + (r'\.[0-9]+', Number.Float, '#pop'), + (r'[0-9]+[eE][+\-]?[0-9]+', Number.Float, '#pop'), + (r'[0-9]+\.[0-9]*[eE][+\-]?[0-9]+', Number.Float, '#pop'), + (r'[0-9]+\.[0-9]+', Number.Float, '#pop'), + (r'[0-9]+\.(?!' + ident + r'|\.\.)', Number.Float, '#pop'), + + # Int + (r'0x[0-9a-fA-F]+', Number.Hex, '#pop'), + (r'[0-9]+', Number.Integer, '#pop'), + + # String + (r"'", String.Single, ('#pop', 'string-single')), + (r'"', String.Double, ('#pop', 'string-double')), + + # EReg + (r'~/(\\\\|\\[^\\]|[^/\\\n])*/[gim]*', String.Regex, '#pop'), + + # Array + (r'\[', Operator, ('#pop', 'array-decl')), + + include('type'), + ], + + # type-param part of a type + # ie. the path in Map + 'type-param': [ + include('spaces'), + default(('#pop', 'type-param-sep', 'type-param-type')), + ], + + 'type-param-sep': [ + include('spaces'), + (r'>', Punctuation, '#pop'), + (r',', Punctuation, ('#pop', 'type-param')), + ], + + # optional type-param that may include constraint + # ie. + 'type-param-constraint': [ + include('spaces'), + (r'<(?!=)', Punctuation, ('#pop', 'type-param-constraint-sep', + 'type-param-constraint-flag', 'type-name')), + default('#pop'), + ], + + 'type-param-constraint-sep': [ + include('spaces'), + (r'>', Punctuation, '#pop'), + (r',', Punctuation, ('#pop', 'type-param-constraint-sep', + 'type-param-constraint-flag', 'type-name')), + ], + + # the optional constraint inside type-param + 'type-param-constraint-flag': [ + include('spaces'), + (r':', Punctuation, ('#pop', 'type-param-constraint-flag-type')), + default('#pop'), + ], + + 'type-param-constraint-flag-type': [ + include('spaces'), + (r'\(', Punctuation, ('#pop', 'type-param-constraint-flag-type-sep', + 'type')), + default(('#pop', 'type')), + ], + + 'type-param-constraint-flag-type-sep': [ + include('spaces'), + (r'\)', Punctuation, '#pop'), + (r',', Punctuation, 'type'), + ], + + # a parenthesis expr that contain exactly one expr + 'parenthesis': [ + include('spaces'), + default(('#pop', 'parenthesis-close', 'flag', 'expr')), + ], + + 'parenthesis-open': [ + include('spaces'), + (r'\(', Punctuation, '#pop'), + ], + + 'parenthesis-close': [ + include('spaces'), + (r'\)', Punctuation, '#pop'), + ], + + 'var': [ + include('spaces'), + (ident_no_keyword, Text, ('#pop', 'var-sep', 'assign', 'flag', 'prop-get-set')), + ], + + # optional more var decl. + 'var-sep': [ + include('spaces'), + (r',', Punctuation, ('#pop', 'var')), + default('#pop'), + ], + + # optional assignment + 'assign': [ + include('spaces'), + (r'=', Operator, ('#pop', 'expr')), + default('#pop'), + ], + + # optional type flag + 'flag': [ + include('spaces'), + (r':', Punctuation, ('#pop', 'type')), + default('#pop'), + ], + + # colon as part of a ternary operator (?:) + 'ternary': [ + include('spaces'), + (r':', Operator, '#pop'), + ], + + # function call + 'call': [ + include('spaces'), + (r'\)', Punctuation, '#pop'), + default(('#pop', 'call-sep', 'expr')), + ], + + # after a call param + 'call-sep': [ + include('spaces'), + (r'\)', Punctuation, '#pop'), + (r',', Punctuation, ('#pop', 'call')), + ], + + # bracket can be block or object + 'bracket': [ + include('spaces'), + (r'(?!(?:\$\s*[a-z]\b|\$(?!'+ident+')))' + ident_no_keyword, Name, + ('#pop', 'bracket-check')), + (r"'", String.Single, ('#pop', 'bracket-check', 'string-single')), + (r'"', String.Double, ('#pop', 'bracket-check', 'string-double')), + default(('#pop', 'block')), + ], + + 'bracket-check': [ + include('spaces'), + (r':', Punctuation, ('#pop', 'object-sep', 'expr')), # is object + default(('#pop', 'block', 'optional-semicolon', 'expr-chain')), # is block + ], + + # code block + 'block': [ + include('spaces'), + (r'\}', Punctuation, '#pop'), + default('expr-statement'), + ], + + # object in key-value pairs + 'object': [ + include('spaces'), + (r'\}', Punctuation, '#pop'), + default(('#pop', 'object-sep', 'expr', 'colon', 'ident-or-string')) + ], + + # a key of an object + 'ident-or-string': [ + include('spaces'), + (ident_no_keyword, Name, '#pop'), + (r"'", String.Single, ('#pop', 'string-single')), + (r'"', String.Double, ('#pop', 'string-double')), + ], + + # after a key-value pair in object + 'object-sep': [ + include('spaces'), + (r'\}', Punctuation, '#pop'), + (r',', Punctuation, ('#pop', 'object')), + ], + + + + } + + def analyse_text(text): + if re.match(r'\w+\s*:\s*\w', text): + return 0.3 + + +class HxmlLexer(RegexLexer): + """ + Lexer for haXe build files. + """ + name = 'Hxml' + url = 'https://haxe.org/manual/compiler-usage-hxml.html' + aliases = ['haxeml', 'hxml'] + filenames = ['*.hxml'] + version_added = '1.6' + + tokens = { + 'root': [ + # Separator + (r'(--)(next)', bygroups(Punctuation, Generic.Heading)), + # Compiler switches with one dash + (r'(-)(prompt|debug|v)', bygroups(Punctuation, Keyword.Keyword)), + # Compilerswitches with two dashes + (r'(--)(neko-source|flash-strict|flash-use-stage|no-opt|no-traces|' + r'no-inline|times|no-output)', bygroups(Punctuation, Keyword)), + # Targets and other options that take an argument + (r'(-)(cpp|js|neko|x|as3|swf9?|swf-lib|php|xml|main|lib|D|resource|' + r'cp|cmd)( +)(.+)', + bygroups(Punctuation, Keyword, Whitespace, String)), + # Options that take only numerical arguments + (r'(-)(swf-version)( +)(\d+)', + bygroups(Punctuation, Keyword, Whitespace, Number.Integer)), + # An Option that defines the size, the fps and the background + # color of an flash movie + (r'(-)(swf-header)( +)(\d+)(:)(\d+)(:)(\d+)(:)([A-Fa-f0-9]{6})', + bygroups(Punctuation, Keyword, Whitespace, Number.Integer, + Punctuation, Number.Integer, Punctuation, Number.Integer, + Punctuation, Number.Hex)), + # options with two dashes that takes arguments + (r'(--)(js-namespace|php-front|php-lib|remap|gen-hx-classes)( +)' + r'(.+)', bygroups(Punctuation, Keyword, Whitespace, String)), + # Single line comment, multiline ones are not allowed. + (r'#.*', Comment.Single) + ] + } diff --git a/pygments/lexers/hdl.py b/pygments/lexers/hdl.py new file mode 100644 index 0000000000000000000000000000000000000000..d22b66ff6a035d5512aa32997dc66f7788eb0ab2 --- /dev/null +++ b/pygments/lexers/hdl.py @@ -0,0 +1,466 @@ +""" + pygments.lexers.hdl + ~~~~~~~~~~~~~~~~~~~ + + Lexers for hardware descriptor languages. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +import re + +from pygments.lexer import RegexLexer, bygroups, include, using, this, words +from pygments.token import Text, Comment, Operator, Keyword, Name, String, \ + Number, Punctuation, Whitespace + +__all__ = ['VerilogLexer', 'SystemVerilogLexer', 'VhdlLexer'] + + +class VerilogLexer(RegexLexer): + """ + For verilog source code with preprocessor directives. + """ + name = 'verilog' + aliases = ['verilog', 'v'] + filenames = ['*.v'] + mimetypes = ['text/x-verilog'] + url = 'https://en.wikipedia.org/wiki/Verilog' + version_added = '1.4' + + #: optional Comment or Whitespace + _ws = r'(?:\s|//.*?\n|/[*].*?[*]/)+' + + tokens = { + 'root': [ + (r'^\s*`define', Comment.Preproc, 'macro'), + (r'\s+', Whitespace), + (r'(\\)(\n)', bygroups(String.Escape, Whitespace)), # line continuation + (r'/(\\\n)?/(\n|(.|\n)*?[^\\]\n)', Comment.Single), + (r'/(\\\n)?[*](.|\n)*?[*](\\\n)?/', Comment.Multiline), + (r'[{}#@]', Punctuation), + (r'L?"', String, 'string'), + (r"L?'(\\.|\\[0-7]{1,3}|\\x[a-fA-F0-9]{1,2}|[^\\\'\n])'", String.Char), + (r'(\d+\.\d*|\.\d+|\d+)[eE][+-]?\d+[lL]?', Number.Float), + (r'(\d+\.\d*|\.\d+|\d+[fF])[fF]?', Number.Float), + (r'([0-9]+)|(\'h)[0-9a-fA-F]+', Number.Hex), + (r'([0-9]+)|(\'b)[01]+', Number.Bin), + (r'([0-9]+)|(\'d)[0-9]+', Number.Integer), + (r'([0-9]+)|(\'o)[0-7]+', Number.Oct), + (r'\'[01xz]', Number), + (r'\d+[Ll]?', Number.Integer), + (r'[~!%^&*+=|?:<>/-]', Operator), + (r'[()\[\],.;\']', Punctuation), + (r'`[a-zA-Z_]\w*', Name.Constant), + + (r'^(\s*)(package)(\s+)', bygroups(Whitespace, Keyword.Namespace, Text)), + (r'^(\s*)(import)(\s+)', bygroups(Whitespace, Keyword.Namespace, Text), + 'import'), + + (words(( + 'always', 'always_comb', 'always_ff', 'always_latch', 'and', + 'assign', 'automatic', 'begin', 'break', 'buf', 'bufif0', 'bufif1', + 'case', 'casex', 'casez', 'cmos', 'const', 'continue', 'deassign', + 'default', 'defparam', 'disable', 'do', 'edge', 'else', 'end', 'endcase', + 'endfunction', 'endgenerate', 'endmodule', 'endpackage', 'endprimitive', + 'endspecify', 'endtable', 'endtask', 'enum', 'event', 'final', 'for', + 'force', 'forever', 'fork', 'function', 'generate', 'genvar', 'highz0', + 'highz1', 'if', 'initial', 'inout', 'input', 'integer', 'join', 'large', + 'localparam', 'macromodule', 'medium', 'module', 'nand', 'negedge', + 'nmos', 'nor', 'not', 'notif0', 'notif1', 'or', 'output', 'packed', + 'parameter', 'pmos', 'posedge', 'primitive', 'pull0', 'pull1', + 'pulldown', 'pullup', 'rcmos', 'ref', 'release', 'repeat', 'return', + 'rnmos', 'rpmos', 'rtran', 'rtranif0', 'rtranif1', 'scalared', 'signed', + 'small', 'specify', 'specparam', 'strength', 'string', 'strong0', + 'strong1', 'struct', 'table', 'task', 'tran', 'tranif0', 'tranif1', + 'type', 'typedef', 'unsigned', 'var', 'vectored', 'void', 'wait', + 'weak0', 'weak1', 'while', 'xnor', 'xor'), suffix=r'\b'), + Keyword), + + (words(( + 'accelerate', 'autoexpand_vectornets', 'celldefine', 'default_nettype', + 'else', 'elsif', 'endcelldefine', 'endif', 'endprotect', 'endprotected', + 'expand_vectornets', 'ifdef', 'ifndef', 'include', 'noaccelerate', + 'noexpand_vectornets', 'noremove_gatenames', 'noremove_netnames', + 'nounconnected_drive', 'protect', 'protected', 'remove_gatenames', + 'remove_netnames', 'resetall', 'timescale', 'unconnected_drive', + 'undef'), prefix=r'`', suffix=r'\b'), + Comment.Preproc), + + (words(( + 'bits', 'bitstoreal', 'bitstoshortreal', 'countdrivers', 'display', 'fclose', + 'fdisplay', 'finish', 'floor', 'fmonitor', 'fopen', 'fstrobe', 'fwrite', + 'getpattern', 'history', 'incsave', 'input', 'itor', 'key', 'list', 'log', + 'monitor', 'monitoroff', 'monitoron', 'nokey', 'nolog', 'printtimescale', + 'random', 'readmemb', 'readmemh', 'realtime', 'realtobits', 'reset', + 'reset_count', 'reset_value', 'restart', 'rtoi', 'save', 'scale', 'scope', + 'shortrealtobits', 'showscopes', 'showvariables', 'showvars', 'sreadmemb', + 'sreadmemh', 'stime', 'stop', 'strobe', 'time', 'timeformat', 'write'), + prefix=r'\$', suffix=r'\b'), + Name.Builtin), + + (words(( + 'byte', 'shortint', 'int', 'longint', 'integer', 'time', + 'bit', 'logic', 'reg', 'supply0', 'supply1', 'tri', 'triand', + 'trior', 'tri0', 'tri1', 'trireg', 'uwire', 'wire', 'wand', 'wor' + 'shortreal', 'real', 'realtime'), suffix=r'\b'), + Keyword.Type), + (r'[a-zA-Z_]\w*:(?!:)', Name.Label), + (r'\$?[a-zA-Z_]\w*', Name), + (r'\\(\S+)', Name), + ], + 'string': [ + (r'"', String, '#pop'), + (r'\\([\\abfnrtv"\']|x[a-fA-F0-9]{2,4}|[0-7]{1,3})', String.Escape), + (r'[^\\"\n]+', String), # all other characters + (r'(\\)(\n)', bygroups(String.Escape, Whitespace)), # line continuation + (r'\\', String), # stray backslash + ], + 'macro': [ + (r'[^/\n]+', Comment.Preproc), + (r'/[*](.|\n)*?[*]/', Comment.Multiline), + (r'//.*?\n', Comment.Single, '#pop'), + (r'/', Comment.Preproc), + (r'(?<=\\)\n', Comment.Preproc), + (r'\n', Whitespace, '#pop'), + ], + 'import': [ + (r'[\w:]+\*?', Name.Namespace, '#pop') + ] + } + + def analyse_text(text): + """Verilog code will use one of reg/wire/assign for sure, and that + is not common elsewhere.""" + result = 0 + if 'reg' in text: + result += 0.1 + if 'wire' in text: + result += 0.1 + if 'assign' in text: + result += 0.1 + + return result + + +class SystemVerilogLexer(RegexLexer): + """ + Extends verilog lexer to recognise all SystemVerilog keywords from IEEE + 1800-2009 standard. + """ + name = 'systemverilog' + aliases = ['systemverilog', 'sv'] + filenames = ['*.sv', '*.svh'] + mimetypes = ['text/x-systemverilog'] + url = 'https://en.wikipedia.org/wiki/SystemVerilog' + version_added = '1.5' + + #: optional Comment or Whitespace + _ws = r'(?:\s|//.*?\n|/[*].*?[*]/)+' + + tokens = { + 'root': [ + (r'^(\s*)(`define)', bygroups(Whitespace, Comment.Preproc), 'macro'), + (r'^(\s*)(package)(\s+)', bygroups(Whitespace, Keyword.Namespace, Whitespace)), + (r'^(\s*)(import)(\s+)', bygroups(Whitespace, Keyword.Namespace, Whitespace), 'import'), + + (r'\s+', Whitespace), + (r'(\\)(\n)', bygroups(String.Escape, Whitespace)), # line continuation + (r'/(\\\n)?/(\n|(.|\n)*?[^\\]\n)', Comment.Single), + (r'/(\\\n)?[*](.|\n)*?[*](\\\n)?/', Comment.Multiline), + (r'[{}#@]', Punctuation), + (r'L?"', String, 'string'), + (r"L?'(\\.|\\[0-7]{1,3}|\\x[a-fA-F0-9]{1,2}|[^\\\'\n])'", String.Char), + + (r'(\d+\.\d*|\.\d+|\d+)[eE][+-]?\d+[lL]?', Number.Float), + (r'(\d+\.\d*|\.\d+|\d+[fF])[fF]?', Number.Float), + + (r'([1-9][_0-9]*)?\s*\'[sS]?[bB]\s*[xXzZ?01][_xXzZ?01]*', + Number.Bin), + (r'([1-9][_0-9]*)?\s*\'[sS]?[oO]\s*[xXzZ?0-7][_xXzZ?0-7]*', + Number.Oct), + (r'([1-9][_0-9]*)?\s*\'[sS]?[dD]\s*[xXzZ?0-9][_xXzZ?0-9]*', + Number.Integer), + (r'([1-9][_0-9]*)?\s*\'[sS]?[hH]\s*[xXzZ?0-9a-fA-F][_xXzZ?0-9a-fA-F]*', + Number.Hex), + + (r'\'[01xXzZ]', Number), + (r'[0-9][_0-9]*', Number.Integer), + + (r'[~!%^&*+=|?:<>/-]', Operator), + (words(('inside', 'dist'), suffix=r'\b'), Operator.Word), + + (r'[()\[\],.;\'$]', Punctuation), + (r'`[a-zA-Z_]\w*', Name.Constant), + + (words(( + 'accept_on', 'alias', 'always', 'always_comb', 'always_ff', + 'always_latch', 'and', 'assert', 'assign', 'assume', 'automatic', + 'before', 'begin', 'bind', 'bins', 'binsof', 'break', 'buf', + 'bufif0', 'bufif1', 'case', 'casex', 'casez', 'cell', + 'checker', 'clocking', 'cmos', 'config', + 'constraint', 'context', 'continue', 'cover', 'covergroup', + 'coverpoint', 'cross', 'deassign', 'default', 'defparam', 'design', + 'disable', 'do', 'edge', 'else', 'end', 'endcase', + 'endchecker', 'endclocking', 'endconfig', 'endfunction', + 'endgenerate', 'endgroup', 'endinterface', 'endmodule', 'endpackage', + 'endprimitive', 'endprogram', 'endproperty', 'endsequence', + 'endspecify', 'endtable', 'endtask', 'enum', 'eventually', + 'expect', 'export', 'extern', 'final', 'first_match', + 'for', 'force', 'foreach', 'forever', 'fork', 'forkjoin', 'function', + 'generate', 'genvar', 'global', 'highz0', 'highz1', 'if', 'iff', + 'ifnone', 'ignore_bins', 'illegal_bins', 'implies', 'implements', 'import', + 'incdir', 'include', 'initial', 'inout', 'input', + 'instance', 'interconnect', 'interface', 'intersect', 'join', + 'join_any', 'join_none', 'large', 'let', 'liblist', 'library', + 'local', 'localparam', 'macromodule', 'matches', + 'medium', 'modport', 'module', 'nand', 'negedge', 'nettype', 'new', 'nexttime', + 'nmos', 'nor', 'noshowcancelled', 'not', 'notif0', 'notif1', 'null', + 'or', 'output', 'package', 'packed', 'parameter', 'pmos', 'posedge', + 'primitive', 'priority', 'program', 'property', 'protected', 'pull0', + 'pull1', 'pulldown', 'pullup', 'pulsestyle_ondetect', + 'pulsestyle_onevent', 'pure', 'rand', 'randc', 'randcase', + 'randsequence', 'rcmos', 'ref', + 'reject_on', 'release', 'repeat', 'restrict', 'return', 'rnmos', + 'rpmos', 'rtran', 'rtranif0', 'rtranif1', 's_always', 's_eventually', + 's_nexttime', 's_until', 's_until_with', 'scalared', 'sequence', + 'showcancelled', 'small', 'soft', 'solve', + 'specify', 'specparam', 'static', 'strong', 'strong0', + 'strong1', 'struct', 'super', 'sync_accept_on', + 'sync_reject_on', 'table', 'tagged', 'task', 'this', 'throughout', + 'timeprecision', 'timeunit', 'tran', 'tranif0', 'tranif1', + 'typedef', 'union', 'unique', 'unique0', 'until', + 'until_with', 'untyped', 'use', 'vectored', + 'virtual', 'wait', 'wait_order', 'weak', 'weak0', + 'weak1', 'while', 'wildcard', 'with', 'within', + 'xnor', 'xor'), + suffix=r'\b'), + Keyword), + + (r'(class)(\s+)([a-zA-Z_]\w*)', + bygroups(Keyword.Declaration, Whitespace, Name.Class)), + (r'(extends)(\s+)([a-zA-Z_]\w*)', + bygroups(Keyword.Declaration, Whitespace, Name.Class)), + (r'(endclass\b)(?:(\s*)(:)(\s*)([a-zA-Z_]\w*))?', + bygroups(Keyword.Declaration, Whitespace, Punctuation, Whitespace, Name.Class)), + + (words(( + # Variable types + 'bit', 'byte', 'chandle', 'const', 'event', 'int', 'integer', + 'logic', 'longint', 'real', 'realtime', 'reg', 'shortint', + 'shortreal', 'signed', 'string', 'time', 'type', 'unsigned', + 'var', 'void', + # Net types + 'supply0', 'supply1', 'tri', 'triand', 'trior', 'trireg', + 'tri0', 'tri1', 'uwire', 'wand', 'wire', 'wor'), + suffix=r'\b'), + Keyword.Type), + + (words(( + '`__FILE__', '`__LINE__', '`begin_keywords', '`celldefine', + '`default_nettype', '`define', '`else', '`elsif', '`end_keywords', + '`endcelldefine', '`endif', '`ifdef', '`ifndef', '`include', + '`line', '`nounconnected_drive', '`pragma', '`resetall', + '`timescale', '`unconnected_drive', '`undef', '`undefineall'), + suffix=r'\b'), + Comment.Preproc), + + (words(( + # Simulation control tasks (20.2) + '$exit', '$finish', '$stop', + # Simulation time functions (20.3) + '$realtime', '$stime', '$time', + # Timescale tasks (20.4) + '$printtimescale', '$timeformat', + # Conversion functions + '$bitstoreal', '$bitstoshortreal', '$cast', '$itor', + '$realtobits', '$rtoi', '$shortrealtobits', '$signed', + '$unsigned', + # Data query functions (20.6) + '$bits', '$isunbounded', '$typename', + # Array query functions (20.7) + '$dimensions', '$high', '$increment', '$left', '$low', '$right', + '$size', '$unpacked_dimensions', + # Math functions (20.8) + '$acos', '$acosh', '$asin', '$asinh', '$atan', '$atan2', + '$atanh', '$ceil', '$clog2', '$cos', '$cosh', '$exp', '$floor', + '$hypot', '$ln', '$log10', '$pow', '$sin', '$sinh', '$sqrt', + '$tan', '$tanh', + # Bit vector system functions (20.9) + '$countbits', '$countones', '$isunknown', '$onehot', '$onehot0', + # Severity tasks (20.10) + '$info', '$error', '$fatal', '$warning', + # Assertion control tasks (20.12) + '$assertcontrol', '$assertfailoff', '$assertfailon', + '$assertkill', '$assertnonvacuouson', '$assertoff', '$asserton', + '$assertpassoff', '$assertpasson', '$assertvacuousoff', + # Sampled value system functions (20.13) + '$changed', '$changed_gclk', '$changing_gclk', '$falling_gclk', + '$fell', '$fell_gclk', '$future_gclk', '$past', '$past_gclk', + '$rising_gclk', '$rose', '$rose_gclk', '$sampled', '$stable', + '$stable_gclk', '$steady_gclk', + # Coverage control functions (20.14) + '$coverage_control', '$coverage_get', '$coverage_get_max', + '$coverage_merge', '$coverage_save', '$get_coverage', + '$load_coverage_db', '$set_coverage_db_name', + # Probabilistic distribution functions (20.15) + '$dist_chi_square', '$dist_erlang', '$dist_exponential', + '$dist_normal', '$dist_poisson', '$dist_t', '$dist_uniform', + '$random', + # Stochastic analysis tasks and functions (20.16) + '$q_add', '$q_exam', '$q_full', '$q_initialize', '$q_remove', + # PLA modeling tasks (20.17) + '$async$and$array', '$async$and$plane', '$async$nand$array', + '$async$nand$plane', '$async$nor$array', '$async$nor$plane', + '$async$or$array', '$async$or$plane', '$sync$and$array', + '$sync$and$plane', '$sync$nand$array', '$sync$nand$plane', + '$sync$nor$array', '$sync$nor$plane', '$sync$or$array', + '$sync$or$plane', + # Miscellaneous tasks and functions (20.18) + '$system', + # Display tasks (21.2) + '$display', '$displayb', '$displayh', '$displayo', '$monitor', + '$monitorb', '$monitorh', '$monitoro', '$monitoroff', + '$monitoron', '$strobe', '$strobeb', '$strobeh', '$strobeo', + '$write', '$writeb', '$writeh', '$writeo', + # File I/O tasks and functions (21.3) + '$fclose', '$fdisplay', '$fdisplayb', '$fdisplayh', + '$fdisplayo', '$feof', '$ferror', '$fflush', '$fgetc', '$fgets', + '$fmonitor', '$fmonitorb', '$fmonitorh', '$fmonitoro', '$fopen', + '$fread', '$fscanf', '$fseek', '$fstrobe', '$fstrobeb', + '$fstrobeh', '$fstrobeo', '$ftell', '$fwrite', '$fwriteb', + '$fwriteh', '$fwriteo', '$rewind', '$sformat', '$sformatf', + '$sscanf', '$swrite', '$swriteb', '$swriteh', '$swriteo', + '$ungetc', + # Memory load tasks (21.4) + '$readmemb', '$readmemh', + # Memory dump tasks (21.5) + '$writememb', '$writememh', + # Command line input (21.6) + '$test$plusargs', '$value$plusargs', + # VCD tasks (21.7) + '$dumpall', '$dumpfile', '$dumpflush', '$dumplimit', '$dumpoff', + '$dumpon', '$dumpports', '$dumpportsall', '$dumpportsflush', + '$dumpportslimit', '$dumpportsoff', '$dumpportson', '$dumpvars', + ), suffix=r'\b'), + Name.Builtin), + + (r'[a-zA-Z_]\w*:(?!:)', Name.Label), + (r'\$?[a-zA-Z_]\w*', Name), + (r'\\(\S+)', Name), + ], + 'string': [ + (r'"', String, '#pop'), + (r'\\([\\abfnrtv"\']|x[a-fA-F0-9]{2,4}|[0-7]{1,3})', String.Escape), + (r'[^\\"\n]+', String), # all other characters + (r'(\\)(\n)', bygroups(String.Escape, Whitespace)), # line continuation + (r'\\', String), # stray backslash + ], + 'macro': [ + (r'[^/\n]+', Comment.Preproc), + (r'/[*](.|\n)*?[*]/', Comment.Multiline), + (r'//.*?$', Comment.Single, '#pop'), + (r'/', Comment.Preproc), + (r'(?<=\\)\n', Comment.Preproc), + (r'\n', Whitespace, '#pop'), + ], + 'import': [ + (r'[\w:]+\*?', Name.Namespace, '#pop') + ] + } + + +class VhdlLexer(RegexLexer): + """ + For VHDL source code. + """ + name = 'vhdl' + aliases = ['vhdl'] + filenames = ['*.vhdl', '*.vhd'] + mimetypes = ['text/x-vhdl'] + url = 'https://en.wikipedia.org/wiki/VHDL' + version_added = '1.5' + flags = re.MULTILINE | re.IGNORECASE + + tokens = { + 'root': [ + (r'\s+', Whitespace), + (r'(\\)(\n)', bygroups(String.Escape, Whitespace)), # line continuation + (r'--.*?$', Comment.Single), + (r'/(\\\n)?[*](.|\n)*?[*](\\\n)?/', Comment.Multiline), + (r"'(U|X|0|1|Z|W|L|H|-)'", String.Char), + (r'[~!%^&*+=|?:<>/-]', Operator), + (r"'[a-z_]\w*", Name.Attribute), + (r'[()\[\],.;\']', Punctuation), + (r'"[^\n\\"]*"', String), + + (r'(library)(\s+)([a-z_]\w*)', + bygroups(Keyword, Whitespace, Name.Namespace)), + (r'(use)(\s+)(entity)', bygroups(Keyword, Whitespace, Keyword)), + (r'(use)(\s+)([a-z_][\w.]*\.)(all)', + bygroups(Keyword, Whitespace, Name.Namespace, Keyword)), + (r'(use)(\s+)([a-z_][\w.]*)', + bygroups(Keyword, Whitespace, Name.Namespace)), + (r'(std|ieee)(\.[a-z_]\w*)', + bygroups(Name.Namespace, Name.Namespace)), + (words(('std', 'ieee', 'work'), suffix=r'\b'), + Name.Namespace), + (r'(entity|component)(\s+)([a-z_]\w*)', + bygroups(Keyword, Whitespace, Name.Class)), + (r'(architecture|configuration)(\s+)([a-z_]\w*)(\s+)' + r'(of)(\s+)([a-z_]\w*)(\s+)(is)', + bygroups(Keyword, Whitespace, Name.Class, Whitespace, Keyword, Whitespace, + Name.Class, Whitespace, Keyword)), + (r'([a-z_]\w*)(:)(\s+)(process|for)', + bygroups(Name.Class, Operator, Whitespace, Keyword)), + (r'(end)(\s+)', bygroups(using(this), Whitespace), 'endblock'), + + include('types'), + include('keywords'), + include('numbers'), + + (r'[a-z_]\w*', Name), + ], + 'endblock': [ + include('keywords'), + (r'[a-z_]\w*', Name.Class), + (r'\s+', Whitespace), + (r';', Punctuation, '#pop'), + ], + 'types': [ + (words(( + 'boolean', 'bit', 'character', 'severity_level', 'integer', 'time', + 'delay_length', 'natural', 'positive', 'string', 'bit_vector', + 'file_open_kind', 'file_open_status', 'std_ulogic', 'std_ulogic_vector', + 'std_logic', 'std_logic_vector', 'signed', 'unsigned'), suffix=r'\b'), + Keyword.Type), + ], + 'keywords': [ + (words(( + 'abs', 'access', 'after', 'alias', 'all', 'and', + 'architecture', 'array', 'assert', 'attribute', 'begin', 'block', + 'body', 'buffer', 'bus', 'case', 'component', 'configuration', + 'constant', 'disconnect', 'downto', 'else', 'elsif', 'end', + 'entity', 'exit', 'file', 'for', 'function', 'generate', + 'generic', 'group', 'guarded', 'if', 'impure', 'in', + 'inertial', 'inout', 'is', 'label', 'library', 'linkage', + 'literal', 'loop', 'map', 'mod', 'nand', 'new', + 'next', 'nor', 'not', 'null', 'of', 'on', + 'open', 'or', 'others', 'out', 'package', 'port', + 'postponed', 'procedure', 'process', 'pure', 'range', 'record', + 'register', 'reject', 'rem', 'return', 'rol', 'ror', 'select', + 'severity', 'signal', 'shared', 'sla', 'sll', 'sra', + 'srl', 'subtype', 'then', 'to', 'transport', 'type', + 'units', 'until', 'use', 'variable', 'wait', 'when', + 'while', 'with', 'xnor', 'xor'), suffix=r'\b'), + Keyword), + ], + 'numbers': [ + (r'\d{1,2}#[0-9a-f_]+#?', Number.Integer), + (r'\d+', Number.Integer), + (r'(\d+\.\d*|\.\d+|\d+)E[+-]?\d+', Number.Float), + (r'X"[0-9a-f_]+"', Number.Hex), + (r'O"[0-7_]+"', Number.Oct), + (r'B"[01_]+"', Number.Bin), + ], + } diff --git a/pygments/lexers/hexdump.py b/pygments/lexers/hexdump.py new file mode 100644 index 0000000000000000000000000000000000000000..4da4edb6af7bbb2855a923b0fdc618a9054e988c --- /dev/null +++ b/pygments/lexers/hexdump.py @@ -0,0 +1,102 @@ +""" + pygments.lexers.hexdump + ~~~~~~~~~~~~~~~~~~~~~~~ + + Lexers for hexadecimal dumps. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +from pygments.lexer import RegexLexer, bygroups, include +from pygments.token import Name, Number, String, Punctuation, Whitespace + +__all__ = ['HexdumpLexer'] + + +class HexdumpLexer(RegexLexer): + """ + For typical hex dump output formats by the UNIX and GNU/Linux tools ``hexdump``, + ``hd``, ``hexcat``, ``od`` and ``xxd``, and the DOS tool ``DEBUG``. For example: + + .. sourcecode:: hexdump + + 00000000 7f 45 4c 46 02 01 01 00 00 00 00 00 00 00 00 00 |.ELF............| + 00000010 02 00 3e 00 01 00 00 00 c5 48 40 00 00 00 00 00 |..>......H@.....| + + The specific supported formats are the outputs of: + + * ``hexdump FILE`` + * ``hexdump -C FILE`` -- the `canonical` format used in the example. + * ``hd FILE`` -- same as ``hexdump -C FILE``. + * ``hexcat FILE`` + * ``od -t x1z FILE`` + * ``xxd FILE`` + * ``DEBUG.EXE FILE.COM`` and entering ``d`` to the prompt. + """ + name = 'Hexdump' + aliases = ['hexdump'] + url = 'https://en.wikipedia.org/wiki/Hex_dump' + version_added = '2.1' + + hd = r'[0-9A-Ha-h]' + + tokens = { + 'root': [ + (r'\n', Whitespace), + include('offset'), + (r'('+hd+r'{2})(\-)('+hd+r'{2})', + bygroups(Number.Hex, Punctuation, Number.Hex)), + (hd+r'{2}', Number.Hex), + (r'(\s{2,3})(\>)(.{16})(\<)$', + bygroups(Whitespace, Punctuation, String, Punctuation), 'bracket-strings'), + (r'(\s{2,3})(\|)(.{16})(\|)$', + bygroups(Whitespace, Punctuation, String, Punctuation), 'piped-strings'), + (r'(\s{2,3})(\>)(.{1,15})(\<)$', + bygroups(Whitespace, Punctuation, String, Punctuation)), + (r'(\s{2,3})(\|)(.{1,15})(\|)$', + bygroups(Whitespace, Punctuation, String, Punctuation)), + (r'(\s{2,3})(.{1,15})$', bygroups(Whitespace, String)), + (r'(\s{2,3})(.{16}|.{20})$', bygroups(Whitespace, String), 'nonpiped-strings'), + (r'\s', Whitespace), + (r'^\*', Punctuation), + ], + 'offset': [ + (r'^('+hd+'+)(:)', bygroups(Name.Label, Punctuation), 'offset-mode'), + (r'^'+hd+'+', Name.Label), + ], + 'offset-mode': [ + (r'\s', Whitespace, '#pop'), + (hd+'+', Name.Label), + (r':', Punctuation) + ], + 'piped-strings': [ + (r'\n', Whitespace), + include('offset'), + (hd+r'{2}', Number.Hex), + (r'(\s{2,3})(\|)(.{1,16})(\|)$', + bygroups(Whitespace, Punctuation, String, Punctuation)), + (r'\s', Whitespace), + (r'^\*', Punctuation), + ], + 'bracket-strings': [ + (r'\n', Whitespace), + include('offset'), + (hd+r'{2}', Number.Hex), + (r'(\s{2,3})(\>)(.{1,16})(\<)$', + bygroups(Whitespace, Punctuation, String, Punctuation)), + (r'\s', Whitespace), + (r'^\*', Punctuation), + ], + 'nonpiped-strings': [ + (r'\n', Whitespace), + include('offset'), + (r'('+hd+r'{2})(\-)('+hd+r'{2})', + bygroups(Number.Hex, Punctuation, Number.Hex)), + (hd+r'{2}', Number.Hex), + (r'(\s{19,})(.{1,20}?)$', bygroups(Whitespace, String)), + (r'(\s{2,3})(.{1,20})$', bygroups(Whitespace, String)), + (r'\s', Whitespace), + (r'^\*', Punctuation), + ], + } diff --git a/pygments/lexers/html.py b/pygments/lexers/html.py new file mode 100644 index 0000000000000000000000000000000000000000..66ba6609f7d4647c7b0c529cbedca18808dd5eb0 --- /dev/null +++ b/pygments/lexers/html.py @@ -0,0 +1,670 @@ +""" + pygments.lexers.html + ~~~~~~~~~~~~~~~~~~~~ + + Lexers for HTML, XML and related markup. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +import re + +from pygments.lexer import RegexLexer, ExtendedRegexLexer, include, bygroups, \ + default, using, inherit, this +from pygments.token import Text, Comment, Operator, Keyword, Name, String, \ + Punctuation, Whitespace +from pygments.util import looks_like_xml, html_doctype_matches + +from pygments.lexers.javascript import JavascriptLexer +from pygments.lexers.jvm import ScalaLexer +from pygments.lexers.css import CssLexer, _indentation, _starts_block +from pygments.lexers.ruby import RubyLexer + +__all__ = ['HtmlLexer', 'DtdLexer', 'XmlLexer', 'XsltLexer', 'HamlLexer', + 'ScamlLexer', 'PugLexer', 'VueLexer', 'UrlEncodedLexer'] + + +class HtmlLexer(RegexLexer): + """ + For HTML 4 and XHTML 1 markup. Nested JavaScript and CSS is highlighted + by the appropriate lexer. + """ + + name = 'HTML' + url = 'https://html.spec.whatwg.org/' + aliases = ['html'] + filenames = ['*.html', '*.htm', '*.xhtml', '*.xslt'] + mimetypes = ['text/html', 'application/xhtml+xml'] + version_added = '' + + flags = re.IGNORECASE | re.DOTALL + tokens = { + 'root': [ + ('[^<&]+', Text), + (r'&\S*?;', Name.Entity), + (r'\<\!\[CDATA\[.*?\]\]\>', Comment.Preproc), + (r'', Comment.Multiline), + (r'<\?.*?\?>', Comment.Preproc), + (']*>', Comment.Preproc), + (r'(<)(\s*)(script)(\s*)', + bygroups(Punctuation, Text, Name.Tag, Text), + ('script-content', 'tag')), + (r'(<)(\s*)(style)(\s*)', + bygroups(Punctuation, Text, Name.Tag, Text), + ('style-content', 'tag')), + # note: this allows tag names not used in HTML like , + # this is to support yet-unknown template engines and the like + (r'(<)(\s*)([\w:.-]+)', + bygroups(Punctuation, Text, Name.Tag), 'tag'), + (r'(<)(\s*)(/)(\s*)([\w:.-]+)(\s*)(>)', + bygroups(Punctuation, Text, Punctuation, Text, Name.Tag, Text, + Punctuation)), + ], + 'tag': [ + (r'\s+', Text), + (r'([\w:-]+\s*)(=)(\s*)', bygroups(Name.Attribute, Operator, Text), + 'attr'), + (r'[\w:-]+', Name.Attribute), + (r'(/?)(\s*)(>)', bygroups(Punctuation, Text, Punctuation), '#pop'), + ], + 'script-content': [ + (r'(<)(\s*)(/)(\s*)(script)(\s*)(>)', + bygroups(Punctuation, Text, Punctuation, Text, Name.Tag, Text, + Punctuation), '#pop'), + (r'.+?(?=<\s*/\s*script\s*>)', using(JavascriptLexer)), + # fallback cases for when there is no closing script tag + # first look for newline and then go back into root state + # if that fails just read the rest of the file + # this is similar to the error handling logic in lexer.py + (r'.+?\n', using(JavascriptLexer), '#pop'), + (r'.+', using(JavascriptLexer), '#pop'), + ], + 'style-content': [ + (r'(<)(\s*)(/)(\s*)(style)(\s*)(>)', + bygroups(Punctuation, Text, Punctuation, Text, Name.Tag, Text, + Punctuation),'#pop'), + (r'.+?(?=<\s*/\s*style\s*>)', using(CssLexer)), + # fallback cases for when there is no closing style tag + # first look for newline and then go back into root state + # if that fails just read the rest of the file + # this is similar to the error handling logic in lexer.py + (r'.+?\n', using(CssLexer), '#pop'), + (r'.+', using(CssLexer), '#pop'), + ], + 'attr': [ + ('".*?"', String, '#pop'), + ("'.*?'", String, '#pop'), + (r'[^\s>]+', String, '#pop'), + ], + } + + def analyse_text(text): + if html_doctype_matches(text): + return 0.5 + + +class DtdLexer(RegexLexer): + """ + A lexer for DTDs (Document Type Definitions). + """ + + flags = re.MULTILINE | re.DOTALL + + name = 'DTD' + aliases = ['dtd'] + filenames = ['*.dtd'] + mimetypes = ['application/xml-dtd'] + url = 'https://en.wikipedia.org/wiki/Document_type_definition' + version_added = '1.5' + + tokens = { + 'root': [ + include('common'), + + (r'(\s]+)', + bygroups(Keyword, Text, Name.Tag)), + (r'PUBLIC|SYSTEM', Keyword.Constant), + (r'[\[\]>]', Keyword), + ], + + 'common': [ + (r'\s+', Text), + (r'(%|&)[^;]*;', Name.Entity), + ('', Comment, '#pop'), + ('-', Comment), + ], + + 'element': [ + include('common'), + (r'EMPTY|ANY|#PCDATA', Keyword.Constant), + (r'[^>\s|()?+*,]+', Name.Tag), + (r'>', Keyword, '#pop'), + ], + + 'attlist': [ + include('common'), + (r'CDATA|IDREFS|IDREF|ID|NMTOKENS|NMTOKEN|ENTITIES|ENTITY|NOTATION', + Keyword.Constant), + (r'#REQUIRED|#IMPLIED|#FIXED', Keyword.Constant), + (r'xml:space|xml:lang', Keyword.Reserved), + (r'[^>\s|()?+*,]+', Name.Attribute), + (r'>', Keyword, '#pop'), + ], + + 'entity': [ + include('common'), + (r'SYSTEM|PUBLIC|NDATA', Keyword.Constant), + (r'[^>\s|()?+*,]+', Name.Entity), + (r'>', Keyword, '#pop'), + ], + + 'notation': [ + include('common'), + (r'SYSTEM|PUBLIC', Keyword.Constant), + (r'[^>\s|()?+*,]+', Name.Attribute), + (r'>', Keyword, '#pop'), + ], + } + + def analyse_text(text): + if not looks_like_xml(text) and \ + ('', Comment.Preproc), + (r'', Comment.Multiline), + (r'<\?.*?\?>', Comment.Preproc), + (']*>', Comment.Preproc), + (r'<\s*[\w:.-]+', Name.Tag, 'tag'), + (r'<\s*/\s*[\w:.-]+\s*>', Name.Tag), + ], + 'tag': [ + (r'\s+', Whitespace), + (r'[\w.:-]+\s*=', Name.Attribute, 'attr'), + (r'/?\s*>', Name.Tag, '#pop'), + ], + 'attr': [ + (r'\s+', Whitespace), + ('".*?"', String, '#pop'), + ("'.*?'", String, '#pop'), + (r'[^\s>]+', String, '#pop'), + ], + } + + def analyse_text(text): + if looks_like_xml(text): + return 0.45 # less than HTML + + +class XsltLexer(XmlLexer): + """ + A lexer for XSLT. + """ + + name = 'XSLT' + aliases = ['xslt'] + filenames = ['*.xsl', '*.xslt', '*.xpl'] # xpl is XProc + mimetypes = ['application/xsl+xml', 'application/xslt+xml'] + url = 'https://www.w3.org/TR/xslt-30' + version_added = '0.10' + + EXTRA_KEYWORDS = { + 'apply-imports', 'apply-templates', 'attribute', + 'attribute-set', 'call-template', 'choose', 'comment', + 'copy', 'copy-of', 'decimal-format', 'element', 'fallback', + 'for-each', 'if', 'import', 'include', 'key', 'message', + 'namespace-alias', 'number', 'otherwise', 'output', 'param', + 'preserve-space', 'processing-instruction', 'sort', + 'strip-space', 'stylesheet', 'template', 'text', 'transform', + 'value-of', 'variable', 'when', 'with-param' + } + + def get_tokens_unprocessed(self, text): + for index, token, value in XmlLexer.get_tokens_unprocessed(self, text): + m = re.match(']*)/?>?', value) + + if token is Name.Tag and m and m.group(1) in self.EXTRA_KEYWORDS: + yield index, Keyword, value + else: + yield index, token, value + + def analyse_text(text): + if looks_like_xml(text) and ']{1,2}(?=[ \t=])', Punctuation), + include('eval-or-plain'), + ], + + 'plain': [ + (r'([^#\n]|#[^{\n]|(\\\\)*\\#\{)+', Text), + (r'(#\{)(' + _dot + r'*?)(\})', + bygroups(String.Interpol, using(RubyLexer), String.Interpol)), + (r'\n', Text, 'root'), + ], + + 'html-attributes': [ + (r'\s+', Text), + (r'[\w:-]+[ \t]*=', Name.Attribute, 'html-attribute-value'), + (r'[\w:-]+', Name.Attribute), + (r'\)', Text, '#pop'), + ], + + 'html-attribute-value': [ + (r'[ \t]+', Text), + (r'\w+', Name.Variable, '#pop'), + (r'@\w+', Name.Variable.Instance, '#pop'), + (r'\$\w+', Name.Variable.Global, '#pop'), + (r"'(\\\\|\\[^\\]|[^'\\\n])*'", String, '#pop'), + (r'"(\\\\|\\[^\\]|[^"\\\n])*"', String, '#pop'), + ], + + 'html-comment-block': [ + (_dot + '+', Comment), + (r'\n', Text, 'root'), + ], + + 'haml-comment-block': [ + (_dot + '+', Comment.Preproc), + (r'\n', Text, 'root'), + ], + + 'filter-block': [ + (r'([^#\n]|#[^{\n]|(\\\\)*\\#\{)+', Name.Decorator), + (r'(#\{)(' + _dot + r'*?)(\})', + bygroups(String.Interpol, using(RubyLexer), String.Interpol)), + (r'\n', Text, 'root'), + ], + } + + +class ScamlLexer(ExtendedRegexLexer): + """ + For Scaml markup. Scaml is Haml for Scala. + """ + + name = 'Scaml' + aliases = ['scaml'] + filenames = ['*.scaml'] + mimetypes = ['text/x-scaml'] + url = 'https://scalate.github.io/scalate/' + version_added = '1.4' + + flags = re.IGNORECASE + # Scaml does not yet support the " |\n" notation to + # wrap long lines. Once it does, use the custom faux + # dot instead. + # _dot = r'(?: \|\n(?=.* \|)|.)' + _dot = r'.' + + tokens = { + 'root': [ + (r'[ \t]*\n', Text), + (r'[ \t]*', _indentation), + ], + + 'css': [ + (r'\.[\w:-]+', Name.Class, 'tag'), + (r'\#[\w:-]+', Name.Function, 'tag'), + ], + + 'eval-or-plain': [ + (r'[&!]?==', Punctuation, 'plain'), + (r'([&!]?[=~])(' + _dot + r'*\n)', + bygroups(Punctuation, using(ScalaLexer)), + 'root'), + default('plain'), + ], + + 'content': [ + include('css'), + (r'%[\w:-]+', Name.Tag, 'tag'), + (r'!!!' + _dot + r'*\n', Name.Namespace, '#pop'), + (r'(/)(\[' + _dot + r'*?\])(' + _dot + r'*\n)', + bygroups(Comment, Comment.Special, Comment), + '#pop'), + (r'/' + _dot + r'*\n', _starts_block(Comment, 'html-comment-block'), + '#pop'), + (r'-#' + _dot + r'*\n', _starts_block(Comment.Preproc, + 'scaml-comment-block'), '#pop'), + (r'(-@\s*)(import)?(' + _dot + r'*\n)', + bygroups(Punctuation, Keyword, using(ScalaLexer)), + '#pop'), + (r'(-)(' + _dot + r'*\n)', + bygroups(Punctuation, using(ScalaLexer)), + '#pop'), + (r':' + _dot + r'*\n', _starts_block(Name.Decorator, 'filter-block'), + '#pop'), + include('eval-or-plain'), + ], + + 'tag': [ + include('css'), + (r'\{(,\n|' + _dot + r')*?\}', using(ScalaLexer)), + (r'\[' + _dot + r'*?\]', using(ScalaLexer)), + (r'\(', Text, 'html-attributes'), + (r'/[ \t]*\n', Punctuation, '#pop:2'), + (r'[<>]{1,2}(?=[ \t=])', Punctuation), + include('eval-or-plain'), + ], + + 'plain': [ + (r'([^#\n]|#[^{\n]|(\\\\)*\\#\{)+', Text), + (r'(#\{)(' + _dot + r'*?)(\})', + bygroups(String.Interpol, using(ScalaLexer), String.Interpol)), + (r'\n', Text, 'root'), + ], + + 'html-attributes': [ + (r'\s+', Text), + (r'[\w:-]+[ \t]*=', Name.Attribute, 'html-attribute-value'), + (r'[\w:-]+', Name.Attribute), + (r'\)', Text, '#pop'), + ], + + 'html-attribute-value': [ + (r'[ \t]+', Text), + (r'\w+', Name.Variable, '#pop'), + (r'@\w+', Name.Variable.Instance, '#pop'), + (r'\$\w+', Name.Variable.Global, '#pop'), + (r"'(\\\\|\\[^\\]|[^'\\\n])*'", String, '#pop'), + (r'"(\\\\|\\[^\\]|[^"\\\n])*"', String, '#pop'), + ], + + 'html-comment-block': [ + (_dot + '+', Comment), + (r'\n', Text, 'root'), + ], + + 'scaml-comment-block': [ + (_dot + '+', Comment.Preproc), + (r'\n', Text, 'root'), + ], + + 'filter-block': [ + (r'([^#\n]|#[^{\n]|(\\\\)*\\#\{)+', Name.Decorator), + (r'(#\{)(' + _dot + r'*?)(\})', + bygroups(String.Interpol, using(ScalaLexer), String.Interpol)), + (r'\n', Text, 'root'), + ], + } + + +class PugLexer(ExtendedRegexLexer): + """ + For Pug markup. + Pug is a variant of Scaml, see: + http://scalate.fusesource.org/documentation/scaml-reference.html + """ + + name = 'Pug' + aliases = ['pug', 'jade'] + filenames = ['*.pug', '*.jade'] + mimetypes = ['text/x-pug', 'text/x-jade'] + url = 'https://pugjs.org' + version_added = '1.4' + + flags = re.IGNORECASE + _dot = r'.' + + tokens = { + 'root': [ + (r'[ \t]*\n', Text), + (r'[ \t]*', _indentation), + ], + + 'css': [ + (r'\.[\w:-]+', Name.Class, 'tag'), + (r'\#[\w:-]+', Name.Function, 'tag'), + ], + + 'eval-or-plain': [ + (r'[&!]?==', Punctuation, 'plain'), + (r'([&!]?[=~])(' + _dot + r'*\n)', + bygroups(Punctuation, using(ScalaLexer)), 'root'), + default('plain'), + ], + + 'content': [ + include('css'), + (r'!!!' + _dot + r'*\n', Name.Namespace, '#pop'), + (r'(/)(\[' + _dot + r'*?\])(' + _dot + r'*\n)', + bygroups(Comment, Comment.Special, Comment), + '#pop'), + (r'/' + _dot + r'*\n', _starts_block(Comment, 'html-comment-block'), + '#pop'), + (r'-#' + _dot + r'*\n', _starts_block(Comment.Preproc, + 'scaml-comment-block'), '#pop'), + (r'(-@\s*)(import)?(' + _dot + r'*\n)', + bygroups(Punctuation, Keyword, using(ScalaLexer)), + '#pop'), + (r'(-)(' + _dot + r'*\n)', + bygroups(Punctuation, using(ScalaLexer)), + '#pop'), + (r':' + _dot + r'*\n', _starts_block(Name.Decorator, 'filter-block'), + '#pop'), + (r'[\w:-]+', Name.Tag, 'tag'), + (r'\|', Text, 'eval-or-plain'), + ], + + 'tag': [ + include('css'), + (r'\{(,\n|' + _dot + r')*?\}', using(ScalaLexer)), + (r'\[' + _dot + r'*?\]', using(ScalaLexer)), + (r'\(', Text, 'html-attributes'), + (r'/[ \t]*\n', Punctuation, '#pop:2'), + (r'[<>]{1,2}(?=[ \t=])', Punctuation), + include('eval-or-plain'), + ], + + 'plain': [ + (r'([^#\n]|#[^{\n]|(\\\\)*\\#\{)+', Text), + (r'(#\{)(' + _dot + r'*?)(\})', + bygroups(String.Interpol, using(ScalaLexer), String.Interpol)), + (r'\n', Text, 'root'), + ], + + 'html-attributes': [ + (r'\s+', Text), + (r'[\w:-]+[ \t]*=', Name.Attribute, 'html-attribute-value'), + (r'[\w:-]+', Name.Attribute), + (r'\)', Text, '#pop'), + ], + + 'html-attribute-value': [ + (r'[ \t]+', Text), + (r'\w+', Name.Variable, '#pop'), + (r'@\w+', Name.Variable.Instance, '#pop'), + (r'\$\w+', Name.Variable.Global, '#pop'), + (r"'(\\\\|\\[^\\]|[^'\\\n])*'", String, '#pop'), + (r'"(\\\\|\\[^\\]|[^"\\\n])*"', String, '#pop'), + ], + + 'html-comment-block': [ + (_dot + '+', Comment), + (r'\n', Text, 'root'), + ], + + 'scaml-comment-block': [ + (_dot + '+', Comment.Preproc), + (r'\n', Text, 'root'), + ], + + 'filter-block': [ + (r'([^#\n]|#[^{\n]|(\\\\)*\\#\{)+', Name.Decorator), + (r'(#\{)(' + _dot + r'*?)(\})', + bygroups(String.Interpol, using(ScalaLexer), String.Interpol)), + (r'\n', Text, 'root'), + ], + } +JadeLexer = PugLexer # compat + + +class UrlEncodedLexer(RegexLexer): + """ + Lexer for urlencoded data + """ + + name = 'urlencoded' + aliases = ['urlencoded'] + mimetypes = ['application/x-www-form-urlencoded'] + url = 'https://en.wikipedia.org/wiki/Percent-encoding' + version_added = '2.16' + + tokens = { + 'root': [ + ('([^&=]*)(=)([^=&]*)(&?)', bygroups(Name.Tag, Operator, String, Punctuation)), + ], + } + + +class VueLexer(HtmlLexer): + """ + For Vue Single-File Component. + """ + + name = 'Vue' + url = 'https://vuejs.org/api/sfc-spec.html' + aliases = ['vue'] + filenames = ['*.vue'] + mimetypes = [] + version_added = '2.19' + + flags = re.IGNORECASE | re.DOTALL + tokens = { + 'root': [ + (r'(\{\{)(.*?)(\}\})', bygroups(Comment.Preproc, + using(JavascriptLexer), Comment.Preproc)), + ('[^<&{]+', Text), + inherit, + ], + 'tag': [ + (r'\s+', Text), + (r'((?:[@:]|v-)(?:[.\w:-]|\[[^\]]*?\])+\s*)(=)(\s*)', + bygroups(using(this, state=['name']), Operator, Text), + 'attr-directive'), + (r'([\w:-]+\s*)(=)(\s*)', bygroups(Name.Attribute, Operator, Text), + 'attr'), + (r'[\w:-]+', Name.Attribute), + (r'(/?)(\s*)(>)', bygroups(Punctuation, Text, Punctuation), '#pop'), + ], + 'name': [ + (r'[\w-]+', Name.Attribute), + (r'[:@.]', Punctuation), + (r'(\[)([^\]]*?)(\])', bygroups(Comment.Preproc, + using(JavascriptLexer), Comment.Preproc)), + ], + 'attr-directive': [ + (r'(["\'])(.*?)(\1)', bygroups(String, + using(JavascriptLexer), String), '#pop'), + (r'[^\s>]+', using(JavascriptLexer), '#pop'), + ], + } diff --git a/pygments/lexers/idl.py b/pygments/lexers/idl.py new file mode 100644 index 0000000000000000000000000000000000000000..21b8d3130c75de770ae13908b65618973c979b69 --- /dev/null +++ b/pygments/lexers/idl.py @@ -0,0 +1,284 @@ +""" + pygments.lexers.idl + ~~~~~~~~~~~~~~~~~~~ + + Lexers for IDL. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +import re + +from pygments.lexer import RegexLexer, words, bygroups +from pygments.token import Text, Comment, Operator, Keyword, Name, Number, \ + String, Whitespace + +__all__ = ['IDLLexer'] + + +class IDLLexer(RegexLexer): + """ + Pygments Lexer for IDL (Interactive Data Language). + """ + name = 'IDL' + url = 'https://www.l3harrisgeospatial.com/Software-Technology/IDL' + aliases = ['idl'] + filenames = ['*.pro'] + mimetypes = ['text/idl'] + version_added = '1.6' + + flags = re.IGNORECASE | re.MULTILINE + + _RESERVED = ( + 'and', 'begin', 'break', 'case', 'common', 'compile_opt', + 'continue', 'do', 'else', 'end', 'endcase', 'endelse', + 'endfor', 'endforeach', 'endif', 'endrep', 'endswitch', + 'endwhile', 'eq', 'for', 'foreach', 'forward_function', + 'function', 'ge', 'goto', 'gt', 'if', 'inherits', 'le', + 'lt', 'mod', 'ne', 'not', 'of', 'on_ioerror', 'or', 'pro', + 'repeat', 'switch', 'then', 'until', 'while', 'xor') + """Reserved words from: http://www.exelisvis.com/docs/reswords.html""" + + _BUILTIN_LIB = ( + 'abs', 'acos', 'adapt_hist_equal', 'alog', 'alog10', + 'amoeba', 'annotate', 'app_user_dir', 'app_user_dir_query', + 'arg_present', 'array_equal', 'array_indices', 'arrow', + 'ascii_template', 'asin', 'assoc', 'atan', 'axis', + 'a_correlate', 'bandpass_filter', 'bandreject_filter', + 'barplot', 'bar_plot', 'beseli', 'beselj', 'beselk', + 'besely', 'beta', 'bilinear', 'binary_template', 'bindgen', + 'binomial', 'bin_date', 'bit_ffs', 'bit_population', + 'blas_axpy', 'blk_con', 'box_cursor', 'breakpoint', + 'broyden', 'butterworth', 'bytarr', 'byte', 'byteorder', + 'bytscl', 'caldat', 'calendar', 'call_external', + 'call_function', 'call_method', 'call_procedure', 'canny', + 'catch', 'cd', r'cdf_\w*', 'ceil', 'chebyshev', + 'check_math', + 'chisqr_cvf', 'chisqr_pdf', 'choldc', 'cholsol', 'cindgen', + 'cir_3pnt', 'close', 'cluster', 'cluster_tree', 'clust_wts', + 'cmyk_convert', 'colorbar', 'colorize_sample', + 'colormap_applicable', 'colormap_gradient', + 'colormap_rotation', 'colortable', 'color_convert', + 'color_exchange', 'color_quan', 'color_range_map', 'comfit', + 'command_line_args', 'complex', 'complexarr', 'complexround', + 'compute_mesh_normals', 'cond', 'congrid', 'conj', + 'constrained_min', 'contour', 'convert_coord', 'convol', + 'convol_fft', 'coord2to3', 'copy_lun', 'correlate', 'cos', + 'cosh', 'cpu', 'cramer', 'create_cursor', 'create_struct', + 'create_view', 'crossp', 'crvlength', 'cti_test', + 'ct_luminance', 'cursor', 'curvefit', 'cvttobm', 'cv_coord', + 'cw_animate', 'cw_animate_getp', 'cw_animate_load', + 'cw_animate_run', 'cw_arcball', 'cw_bgroup', 'cw_clr_index', + 'cw_colorsel', 'cw_defroi', 'cw_field', 'cw_filesel', + 'cw_form', 'cw_fslider', 'cw_light_editor', + 'cw_light_editor_get', 'cw_light_editor_set', 'cw_orient', + 'cw_palette_editor', 'cw_palette_editor_get', + 'cw_palette_editor_set', 'cw_pdmenu', 'cw_rgbslider', + 'cw_tmpl', 'cw_zoom', 'c_correlate', 'dblarr', 'db_exists', + 'dcindgen', 'dcomplex', 'dcomplexarr', 'define_key', + 'define_msgblk', 'define_msgblk_from_file', 'defroi', + 'defsysv', 'delvar', 'dendrogram', 'dendro_plot', 'deriv', + 'derivsig', 'determ', 'device', 'dfpmin', 'diag_matrix', + 'dialog_dbconnect', 'dialog_message', 'dialog_pickfile', + 'dialog_printersetup', 'dialog_printjob', + 'dialog_read_image', 'dialog_write_image', 'digital_filter', + 'dilate', 'dindgen', 'dissolve', 'dist', 'distance_measure', + 'dlm_load', 'dlm_register', 'doc_library', 'double', + 'draw_roi', 'edge_dog', 'efont', 'eigenql', 'eigenvec', + 'ellipse', 'elmhes', 'emboss', 'empty', 'enable_sysrtn', + 'eof', r'eos_\w*', 'erase', 'erf', 'erfc', 'erfcx', + 'erode', 'errorplot', 'errplot', 'estimator_filter', + 'execute', 'exit', 'exp', 'expand', 'expand_path', 'expint', + 'extrac', 'extract_slice', 'factorial', 'fft', 'filepath', + 'file_basename', 'file_chmod', 'file_copy', 'file_delete', + 'file_dirname', 'file_expand_path', 'file_info', + 'file_lines', 'file_link', 'file_mkdir', 'file_move', + 'file_poll_input', 'file_readlink', 'file_same', + 'file_search', 'file_test', 'file_which', 'findgen', + 'finite', 'fix', 'flick', 'float', 'floor', 'flow3', + 'fltarr', 'flush', 'format_axis_values', 'free_lun', + 'fstat', 'fulstr', 'funct', 'fv_test', 'fx_root', + 'fz_roots', 'f_cvf', 'f_pdf', 'gamma', 'gamma_ct', + 'gauss2dfit', 'gaussfit', 'gaussian_function', 'gaussint', + 'gauss_cvf', 'gauss_pdf', 'gauss_smooth', 'getenv', + 'getwindows', 'get_drive_list', 'get_dxf_objects', + 'get_kbrd', 'get_login_info', 'get_lun', 'get_screen_size', + 'greg2jul', r'grib_\w*', 'grid3', 'griddata', + 'grid_input', 'grid_tps', 'gs_iter', + r'h5[adfgirst]_\w*', 'h5_browser', 'h5_close', + 'h5_create', 'h5_get_libversion', 'h5_open', 'h5_parse', + 'hanning', 'hash', r'hdf_\w*', 'heap_free', + 'heap_gc', 'heap_nosave', 'heap_refcount', 'heap_save', + 'help', 'hilbert', 'histogram', 'hist_2d', 'hist_equal', + 'hls', 'hough', 'hqr', 'hsv', 'h_eq_ct', 'h_eq_int', + 'i18n_multibytetoutf8', 'i18n_multibytetowidechar', + 'i18n_utf8tomultibyte', 'i18n_widechartomultibyte', + 'ibeta', 'icontour', 'iconvertcoord', 'idelete', 'identity', + 'idlexbr_assistant', 'idlitsys_createtool', 'idl_base64', + 'idl_validname', 'iellipse', 'igamma', 'igetcurrent', + 'igetdata', 'igetid', 'igetproperty', 'iimage', 'image', + 'image_cont', 'image_statistics', 'imaginary', 'imap', + 'indgen', 'intarr', 'interpol', 'interpolate', + 'interval_volume', 'int_2d', 'int_3d', 'int_tabulated', + 'invert', 'ioctl', 'iopen', 'iplot', 'ipolygon', + 'ipolyline', 'iputdata', 'iregister', 'ireset', 'iresolve', + 'irotate', 'ir_filter', 'isa', 'isave', 'iscale', + 'isetcurrent', 'isetproperty', 'ishft', 'isocontour', + 'isosurface', 'isurface', 'itext', 'itranslate', 'ivector', + 'ivolume', 'izoom', 'i_beta', 'journal', 'json_parse', + 'json_serialize', 'jul2greg', 'julday', 'keyword_set', + 'krig2d', 'kurtosis', 'kw_test', 'l64indgen', 'label_date', + 'label_region', 'ladfit', 'laguerre', 'laplacian', + 'la_choldc', 'la_cholmprove', 'la_cholsol', 'la_determ', + 'la_eigenproblem', 'la_eigenql', 'la_eigenvec', 'la_elmhes', + 'la_gm_linear_model', 'la_hqr', 'la_invert', + 'la_least_squares', 'la_least_square_equality', + 'la_linear_equation', 'la_ludc', 'la_lumprove', 'la_lusol', + 'la_svd', 'la_tridc', 'la_trimprove', 'la_triql', + 'la_trired', 'la_trisol', 'least_squares_filter', 'leefilt', + 'legend', 'legendre', 'linbcg', 'lindgen', 'linfit', + 'linkimage', 'list', 'll_arc_distance', 'lmfit', 'lmgr', + 'lngamma', 'lnp_test', 'loadct', 'locale_get', + 'logical_and', 'logical_or', 'logical_true', 'lon64arr', + 'lonarr', 'long', 'long64', 'lsode', 'ludc', 'lumprove', + 'lusol', 'lu_complex', 'machar', 'make_array', 'make_dll', + 'make_rt', 'map', 'mapcontinents', 'mapgrid', 'map_2points', + 'map_continents', 'map_grid', 'map_image', 'map_patch', + 'map_proj_forward', 'map_proj_image', 'map_proj_info', + 'map_proj_init', 'map_proj_inverse', 'map_set', + 'matrix_multiply', 'matrix_power', 'max', 'md_test', + 'mean', 'meanabsdev', 'mean_filter', 'median', 'memory', + 'mesh_clip', 'mesh_decimate', 'mesh_issolid', 'mesh_merge', + 'mesh_numtriangles', 'mesh_obj', 'mesh_smooth', + 'mesh_surfacearea', 'mesh_validate', 'mesh_volume', + 'message', 'min', 'min_curve_surf', 'mk_html_help', + 'modifyct', 'moment', 'morph_close', 'morph_distance', + 'morph_gradient', 'morph_hitormiss', 'morph_open', + 'morph_thin', 'morph_tophat', 'multi', 'm_correlate', + r'ncdf_\w*', 'newton', 'noise_hurl', 'noise_pick', + 'noise_scatter', 'noise_slur', 'norm', 'n_elements', + 'n_params', 'n_tags', 'objarr', 'obj_class', 'obj_destroy', + 'obj_hasmethod', 'obj_isa', 'obj_new', 'obj_valid', + 'online_help', 'on_error', 'open', 'oplot', 'oploterr', + 'parse_url', 'particle_trace', 'path_cache', 'path_sep', + 'pcomp', 'plot', 'plot3d', 'ploterr', 'plots', 'plot_3dbox', + 'plot_field', 'pnt_line', 'point_lun', 'polarplot', + 'polar_contour', 'polar_surface', 'poly', 'polyfill', + 'polyfillv', 'polygon', 'polyline', 'polyshade', 'polywarp', + 'poly_2d', 'poly_area', 'poly_fit', 'popd', 'powell', + 'pref_commit', 'pref_get', 'pref_set', 'prewitt', 'primes', + 'print', 'printd', 'product', 'profile', 'profiler', + 'profiles', 'project_vol', 'psafm', 'pseudo', + 'ps_show_fonts', 'ptrarr', 'ptr_free', 'ptr_new', + 'ptr_valid', 'pushd', 'p_correlate', 'qgrid3', 'qhull', + 'qromb', 'qromo', 'qsimp', 'query_ascii', 'query_bmp', + 'query_csv', 'query_dicom', 'query_gif', 'query_image', + 'query_jpeg', 'query_jpeg2000', 'query_mrsid', 'query_pict', + 'query_png', 'query_ppm', 'query_srf', 'query_tiff', + 'query_wav', 'radon', 'randomn', 'randomu', 'ranks', + 'rdpix', 'read', 'reads', 'readu', 'read_ascii', + 'read_binary', 'read_bmp', 'read_csv', 'read_dicom', + 'read_gif', 'read_image', 'read_interfile', 'read_jpeg', + 'read_jpeg2000', 'read_mrsid', 'read_pict', 'read_png', + 'read_ppm', 'read_spr', 'read_srf', 'read_sylk', + 'read_tiff', 'read_wav', 'read_wave', 'read_x11_bitmap', + 'read_xwd', 'real_part', 'rebin', 'recall_commands', + 'recon3', 'reduce_colors', 'reform', 'region_grow', + 'register_cursor', 'regress', 'replicate', + 'replicate_inplace', 'resolve_all', 'resolve_routine', + 'restore', 'retall', 'return', 'reverse', 'rk4', 'roberts', + 'rot', 'rotate', 'round', 'routine_filepath', + 'routine_info', 'rs_test', 'r_correlate', 'r_test', + 'save', 'savgol', 'scale3', 'scale3d', 'scope_level', + 'scope_traceback', 'scope_varfetch', 'scope_varname', + 'search2d', 'search3d', 'sem_create', 'sem_delete', + 'sem_lock', 'sem_release', 'setenv', 'set_plot', + 'set_shading', 'sfit', 'shade_surf', 'shade_surf_irr', + 'shade_volume', 'shift', 'shift_diff', 'shmdebug', 'shmmap', + 'shmunmap', 'shmvar', 'show3', 'showfont', 'simplex', 'sin', + 'sindgen', 'sinh', 'size', 'skewness', 'skip_lun', + 'slicer3', 'slide_image', 'smooth', 'sobel', 'socket', + 'sort', 'spawn', 'spher_harm', 'sph_4pnt', 'sph_scat', + 'spline', 'spline_p', 'spl_init', 'spl_interp', 'sprsab', + 'sprsax', 'sprsin', 'sprstp', 'sqrt', 'standardize', + 'stddev', 'stop', 'strarr', 'strcmp', 'strcompress', + 'streamline', 'stregex', 'stretch', 'string', 'strjoin', + 'strlen', 'strlowcase', 'strmatch', 'strmessage', 'strmid', + 'strpos', 'strput', 'strsplit', 'strtrim', 'struct_assign', + 'struct_hide', 'strupcase', 'surface', 'surfr', 'svdc', + 'svdfit', 'svsol', 'swap_endian', 'swap_endian_inplace', + 'symbol', 'systime', 's_test', 't3d', 'tag_names', 'tan', + 'tanh', 'tek_color', 'temporary', 'tetra_clip', + 'tetra_surface', 'tetra_volume', 'text', 'thin', 'threed', + 'timegen', 'time_test2', 'tm_test', 'total', 'trace', + 'transpose', 'triangulate', 'trigrid', 'triql', 'trired', + 'trisol', 'tri_surf', 'truncate_lun', 'ts_coef', 'ts_diff', + 'ts_fcast', 'ts_smooth', 'tv', 'tvcrs', 'tvlct', 'tvrd', + 'tvscl', 'typename', 't_cvt', 't_pdf', 'uindgen', 'uint', + 'uintarr', 'ul64indgen', 'ulindgen', 'ulon64arr', 'ulonarr', + 'ulong', 'ulong64', 'uniq', 'unsharp_mask', 'usersym', + 'value_locate', 'variance', 'vector', 'vector_field', 'vel', + 'velovect', 'vert_t3d', 'voigt', 'voronoi', 'voxel_proj', + 'wait', 'warp_tri', 'watershed', 'wdelete', 'wf_draw', + 'where', 'widget_base', 'widget_button', 'widget_combobox', + 'widget_control', 'widget_displaycontextmen', 'widget_draw', + 'widget_droplist', 'widget_event', 'widget_info', + 'widget_label', 'widget_list', 'widget_propertysheet', + 'widget_slider', 'widget_tab', 'widget_table', + 'widget_text', 'widget_tree', 'widget_tree_move', + 'widget_window', 'wiener_filter', 'window', 'writeu', + 'write_bmp', 'write_csv', 'write_gif', 'write_image', + 'write_jpeg', 'write_jpeg2000', 'write_nrif', 'write_pict', + 'write_png', 'write_ppm', 'write_spr', 'write_srf', + 'write_sylk', 'write_tiff', 'write_wav', 'write_wave', + 'wset', 'wshow', 'wtn', 'wv_applet', 'wv_cwt', + 'wv_cw_wavelet', 'wv_denoise', 'wv_dwt', 'wv_fn_coiflet', + 'wv_fn_daubechies', 'wv_fn_gaussian', 'wv_fn_haar', + 'wv_fn_morlet', 'wv_fn_paul', 'wv_fn_symlet', + 'wv_import_data', 'wv_import_wavelet', 'wv_plot3d_wps', + 'wv_plot_multires', 'wv_pwt', 'wv_tool_denoise', + 'xbm_edit', 'xdisplayfile', 'xdxf', 'xfont', + 'xinteranimate', 'xloadct', 'xmanager', 'xmng_tmpl', + 'xmtool', 'xobjview', 'xobjview_rotate', + 'xobjview_write_image', 'xpalette', 'xpcolor', 'xplot3d', + 'xregistered', 'xroi', 'xsq_test', 'xsurface', 'xvaredit', + 'xvolume', 'xvolume_rotate', 'xvolume_write_image', + 'xyouts', 'zoom', 'zoom_24') + """Functions from: http://www.exelisvis.com/docs/routines-1.html""" + + tokens = { + 'root': [ + (r'(^\s*)(;.*?)(\n)', bygroups(Whitespace, Comment.Single, + Whitespace)), + (words(_RESERVED, prefix=r'\b', suffix=r'\b'), Keyword), + (words(_BUILTIN_LIB, prefix=r'\b', suffix=r'\b'), Name.Builtin), + (r'\+=|-=|\^=|\*=|/=|#=|##=|<=|>=|=', Operator), + (r'\+\+|--|->|\+|-|##|#|\*|/|<|>|&&|\^|~|\|\|\?|:', Operator), + (r'\b(mod=|lt=|le=|eq=|ne=|ge=|gt=|not=|and=|or=|xor=)', Operator), + (r'\b(mod|lt|le|eq|ne|ge|gt|not|and|or|xor)\b', Operator), + (r'"[^\"]*"', String.Double), + (r"'[^\']*'", String.Single), + (r'\b[+\-]?([0-9]*\.[0-9]+|[0-9]+\.[0-9]*)(D|E)?([+\-]?[0-9]+)?\b', + Number.Float), + (r'\b\'[+\-]?[0-9A-F]+\'X(U?(S?|L{1,2})|B)\b', Number.Hex), + (r'\b\'[+\-]?[0-7]+\'O(U?(S?|L{1,2})|B)\b', Number.Oct), + (r'\b[+\-]?[0-9]+U?L{1,2}\b', Number.Integer.Long), + (r'\b[+\-]?[0-9]+U?S?\b', Number.Integer), + (r'\b[+\-]?[0-9]+B\b', Number), + (r'[ \t]+', Whitespace), + (r'\n', Whitespace), + (r'.', Text), + ] + } + + def analyse_text(text): + """endelse seems to be unique to IDL, endswitch is rare at least.""" + result = 0 + + if 'endelse' in text: + result += 0.2 + if 'endswitch' in text: + result += 0.01 + + return result diff --git a/pygments/lexers/igor.py b/pygments/lexers/igor.py new file mode 100644 index 0000000000000000000000000000000000000000..136c5b89224f0bcad2d7a92be246ef7f138e16e9 --- /dev/null +++ b/pygments/lexers/igor.py @@ -0,0 +1,435 @@ +""" + pygments.lexers.igor + ~~~~~~~~~~~~~~~~~~~~ + + Lexers for Igor Pro. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +import re + +from pygments.lexer import RegexLexer, words +from pygments.token import Text, Comment, Keyword, Name, String, Whitespace + +__all__ = ['IgorLexer'] + + +class IgorLexer(RegexLexer): + """ + Pygments Lexer for Igor Pro procedure files (.ipf). + """ + + name = 'Igor' + aliases = ['igor', 'igorpro'] + filenames = ['*.ipf'] + mimetypes = ['text/ipf'] + url = 'http://www.wavemetrics.com' + version_added = '2.0' + + flags = re.IGNORECASE | re.MULTILINE + + flowControl = ( + 'if', 'else', 'elseif', 'endif', 'for', 'endfor', 'strswitch', 'switch', + 'case', 'default', 'endswitch', 'do', 'while', 'try', 'catch', 'endtry', + 'break', 'continue', 'return', 'AbortOnRTE', 'AbortOnValue' + ) + types = ( + 'variable', 'string', 'constant', 'strconstant', 'NVAR', 'SVAR', 'WAVE', + 'STRUCT', 'dfref', 'funcref', 'char', 'uchar', 'int16', 'uint16', 'int32', + 'uint32', 'int64', 'uint64', 'float', 'double', 'int' + ) + keywords = ( + 'override', 'ThreadSafe', 'MultiThread', 'static', 'Proc', + 'Picture', 'Prompt', 'DoPrompt', 'macro', 'window', 'function', 'end', + 'Structure', 'EndStructure', 'EndMacro', 'Menu', 'SubMenu' + ) + operations = ( + 'Abort', 'AddFIFOData', 'AddFIFOVectData', 'AddMovieAudio', 'AddMovieFrame', + 'AddWavesToBoxPlot', 'AddWavesToViolinPlot', 'AdoptFiles', 'APMath', 'Append', + 'AppendBoxPlot', 'AppendImage', 'AppendLayoutObject', 'AppendMatrixContour', + 'AppendText', 'AppendToGizmo', 'AppendToGraph', 'AppendToLayout', + 'AppendToTable', 'AppendViolinPlot', 'AppendXYZContour', 'AutoPositionWindow', + 'AxonTelegraphFindServers', 'BackgroundInfo', 'Beep', 'BezierToPolygon', + 'BoundingBall', 'BoxSmooth', 'BrowseURL', 'BuildMenu', 'Button', 'cd', 'Chart', + 'CheckBox', 'CheckDisplayed', 'ChooseColor', 'Close', 'CloseHelp', 'CloseMovie', + 'CloseProc', 'ColorScale', 'ColorTab2Wave', 'Concatenate', 'ControlBar', + 'ControlInfo', 'ControlUpdate', 'ConvertGlobalStringTextEncoding', 'ConvexHull', + 'Convolve', 'CopyDimLabels', 'CopyFile', 'CopyFolder', 'CopyScales', 'Correlate', + 'CreateAliasShortcut', 'CreateBrowser', 'Cross', 'CtrlBackground', 'CtrlFIFO', + 'CtrlNamedBackground', 'Cursor', 'CurveFit', 'CustomControl', 'CWT', + 'DAQmx_AI_SetupReader', 'DAQmx_AO_SetOutputs', 'DAQmx_CTR_CountEdges', + 'DAQmx_CTR_OutputPulse', 'DAQmx_CTR_Period', 'DAQmx_CTR_PulseWidth', + 'DAQmx_DeviceInfo', 'DAQmx_DIO_Config', 'DAQmx_DIO_WriteNewData', 'DAQmx_Scan', + 'DAQmx_WaveformGen', 'Debugger', 'DebuggerOptions', 'DefaultFont', + 'DefaultGuiControls', 'DefaultGuiFont', 'DefaultTextEncoding', 'DefineGuide', + 'DelayUpdate', 'DeleteAnnotations', 'DeleteFile', 'DeleteFolder', 'DeletePoints', + 'Differentiate', 'dir', 'Display', 'DisplayHelpTopic', 'DisplayProcedure', + 'DoAlert', 'DoIgorMenu', 'DoUpdate', 'DoWindow', 'DoXOPIdle', 'DPSS', + 'DrawAction', 'DrawArc', 'DrawBezier', 'DrawLine', 'DrawOval', 'DrawPICT', + 'DrawPoly', 'DrawRect', 'DrawRRect', 'DrawText', 'DrawUserShape', 'DSPDetrend', + 'DSPPeriodogram', 'Duplicate', 'DuplicateDataFolder', 'DWT', 'EdgeStats', 'Edit', + 'ErrorBars', 'EstimatePeakSizes', 'Execute', 'ExecuteScriptText', + 'ExperimentInfo', 'ExperimentModified', 'ExportGizmo', 'Extract', + 'FastGaussTransform', 'FastOp', 'FBinRead', 'FBinWrite', 'FCALL_CallFunction', + 'FCALL_FreeLibrary', 'FCALL_GetFunctionList', 'FCALL_GetParamTypeList', + 'FCALL_LoadLibrary', 'FCALL_Version', 'FFT', 'FGetPos', 'FIFOStatus', + 'FIFO2Wave', 'FilterFIR', 'FilterIIR', 'FindAPeak', 'FindContour', + 'FindDuplicates', 'FindLevel', 'FindLevels', 'FindPeak', 'FindPointsInPoly', + 'FindRoots', 'FindSequence', 'FindValue', 'FMaxFlat', 'FPClustering', 'fprintf', + 'FReadLine', 'FSetPos', 'FStatus', 'FTPCreateDirectory', 'FTPDelete', + 'FTPDownload', 'FTPUpload', 'FuncFit', 'FuncFitMD', 'GBLoadWave', 'GetAxis', + 'GetCamera', 'GetFileFolderInfo', 'GetGizmo', 'GetLastUserMenuInfo', + 'GetMarquee', 'GetMouse', 'GetSelection', 'GetWindow', 'GISCreateVectorLayer', + 'GISGetRasterInfo', 'GISGetRegisteredFileInfo', 'GISGetVectorLayerInfo', + 'GISLoadRasterData', 'GISLoadVectorData', 'GISRasterizeVectorData', + 'GISRegisterFile', 'GISTransformCoords', 'GISUnRegisterFile', + 'GISWriteFieldData', 'GISWriteGeometryData', 'GISWriteRaster', + 'GPIBReadBinaryWave2', 'GPIBReadBinary2', 'GPIBReadWave2', 'GPIBRead2', + 'GPIBWriteBinaryWave2', 'GPIBWriteBinary2', 'GPIBWriteWave2', 'GPIBWrite2', + 'GPIB2', 'GraphNormal', 'GraphWaveDraw', 'GraphWaveEdit', 'Grep', 'GroupBox', + 'Hanning', 'HCluster', 'HDFInfo', 'HDFReadImage', 'HDFReadSDS', 'HDFReadVset', + 'HDF5CloseFile', 'HDF5CloseGroup', 'HDF5Control', 'HDF5CreateFile', + 'HDF5CreateGroup', 'HDF5CreateLink', 'HDF5DimensionScale', 'HDF5Dump', + 'HDF5DumpErrors', 'HDF5FlushFile', 'HDF5ListAttributes', 'HDF5ListGroup', + 'HDF5LoadData', 'HDF5LoadGroup', 'HDF5LoadImage', 'HDF5OpenFile', + 'HDF5OpenGroup', 'HDF5SaveData', 'HDF5SaveGroup', 'HDF5SaveImage', + 'HDF5UnlinkObject', 'HideIgorMenus', 'HideInfo', 'HideProcedures', 'HideTools', + 'HilbertTransform', 'Histogram', 'ICA', 'IFFT', 'ImageAnalyzeParticles', + 'ImageBlend', 'ImageBoundaryToMask', 'ImageComposite', 'ImageEdgeDetection', + 'ImageFileInfo', 'ImageFilter', 'ImageFocus', 'ImageFromXYZ', + 'ImageGenerateROIMask', 'ImageGLCM', 'ImageHistModification', 'ImageHistogram', + 'ImageInterpolate', 'ImageLineProfile', 'ImageLoad', 'ImageMorphology', + 'ImageRegistration', 'ImageRemoveBackground', 'ImageRestore', 'ImageRotate', + 'ImageSave', 'ImageSeedFill', 'ImageSkeleton3d', 'ImageSnake', 'ImageStats', + 'ImageThreshold', 'ImageTransform', 'ImageUnwrapPhase', 'ImageWindow', + 'IndexSort', 'InsertPoints', 'InstantFrequency', 'Integrate', 'IntegrateODE', + 'Integrate2D', 'Interpolate2', 'Interpolate3D', 'Interp3DPath', 'ITCCloseAll2', + 'ITCCloseDevice2', 'ITCConfigAllChannels2', 'ITCConfigChannelReset2', + 'ITCConfigChannelUpload2', 'ITCConfigChannel2', 'ITCFIFOAvailableAll2', + 'ITCFIFOAvailable2', 'ITCGetAllChannelsConfig2', 'ITCGetChannelConfig2', + 'ITCGetCurrentDevice2', 'ITCGetDeviceInfo2', 'ITCGetDevices2', + 'ITCGetErrorString2', 'ITCGetSerialNumber2', 'ITCGetState2', 'ITCGetVersions2', + 'ITCInitialize2', 'ITCOpenDevice2', 'ITCReadADC2', 'ITCReadDigital2', + 'ITCReadTimer2', 'ITCSelectDevice2', 'ITCSetDAC2', 'ITCSetGlobals2', + 'ITCSetModes2', 'ITCSetState2', 'ITCStartAcq2', 'ITCStopAcq2', + 'ITCUpdateFIFOPositionAll2', 'ITCUpdateFIFOPosition2', 'ITCWriteDigital2', + 'JCAMPLoadWave', 'JointHistogram', 'JSONXOP_AddTree', 'JSONXOP_AddValue', + 'JSONXOP_Dump', 'JSONXOP_GetArraySize', 'JSONXOP_GetKeys', + 'JSONXOP_GetMaxArraySize', 'JSONXOP_GetType', 'JSONXOP_GetValue', 'JSONXOP_New', + 'JSONXOP_Parse', 'JSONXOP_Release', 'JSONXOP_Remove', 'JSONXOP_Version', + 'KillBackground', 'KillControl', 'KillDataFolder', 'KillFIFO', 'KillFreeAxis', + 'KillPath', 'KillPICTs', 'KillStrings', 'KillVariables', 'KillWaves', + 'KillWindow', 'KMeans', 'Label', 'Layout', 'LayoutPageAction', + 'LayoutSlideShow', 'Legend', 'LinearFeedbackShiftRegister', 'ListBox', + 'LoadData', 'LoadPackagePreferences', 'LoadPICT', 'LoadWave', 'Loess', + 'LombPeriodogram', 'Make', 'MakeIndex', 'MarkPerfTestTime', 'MatrixBalance', + 'MatrixConvolve', 'MatrixCorr', 'MatrixEigenV', 'MatrixFactor', 'MatrixFilter', + 'MatrixGaussJ', 'MatrixGLM', 'MatrixInverse', 'MatrixLinearSolve', + 'MatrixLinearSolveTD', 'MatrixLLS', 'MatrixLUBkSub', 'MatrixLUD', 'MatrixLUDTD', + 'MatrixMultiply', 'MatrixMultiplyAdd', 'MatrixOP', 'MatrixReverseBalance', + 'MatrixSchur', 'MatrixSolve', 'MatrixSparse', 'MatrixSVBkSub', 'MatrixSVD', + 'MatrixTranspose', 'MCC_FindServers', 'MeasureStyledText', + 'MFR_CheckForNewBricklets', 'MFR_CloseResultFile', 'MFR_CreateOverviewTable', + 'MFR_GetBrickletCount', 'MFR_GetBrickletData', 'MFR_GetBrickletDeployData', + 'MFR_GetBrickletMetaData', 'MFR_GetBrickletRawData', 'MFR_GetReportTemplate', + 'MFR_GetResultFileMetaData', 'MFR_GetResultFileName', + 'MFR_GetVernissageVersion', 'MFR_GetVersion', 'MFR_GetXOPErrorMessage', + 'MFR_OpenResultFile', 'MLLoadWave', 'Modify', 'ModifyBoxPlot', 'ModifyBrowser', + 'ModifyCamera', 'ModifyContour', 'ModifyControl', 'ModifyControlList', + 'ModifyFreeAxis', 'ModifyGizmo', 'ModifyGraph', 'ModifyImage', 'ModifyLayout', + 'ModifyPanel', 'ModifyProcedure', 'ModifyTable', 'ModifyViolinPlot', + 'ModifyWaterfall', 'MoveDataFolder', 'MoveFile', 'MoveFolder', 'MoveString', + 'MoveSubwindow', 'MoveVariable', 'MoveWave', 'MoveWindow', 'MultiTaperPSD', + 'MultiThreadingControl', 'NC_CloseFile', 'NC_DumpErrors', 'NC_Inquire', + 'NC_ListAttributes', 'NC_ListObjects', 'NC_LoadData', 'NC_OpenFile', + 'NeuralNetworkRun', 'NeuralNetworkTrain', 'NewCamera', 'NewDataFolder', + 'NewFIFO', 'NewFIFOChan', 'NewFreeAxis', 'NewGizmo', 'NewImage', 'NewLayout', + 'NewMovie', 'NewNotebook', 'NewPanel', 'NewPath', 'NewWaterfall', 'NILoadWave', + 'NI4882', 'Note', 'Notebook', 'NotebookAction', 'Open', 'OpenHelp', + 'OpenNotebook', 'Optimize', 'ParseOperationTemplate', 'PathInfo', + 'PauseForUser', 'PauseUpdate', 'PCA', 'PlayMovie', 'PlayMovieAction', + 'PlaySound', 'PolygonOp', 'PopupContextualMenu', 'PopupMenu', 'Preferences', + 'PrimeFactors', 'Print', 'printf', 'PrintGraphs', 'PrintLayout', + 'PrintNotebook', 'PrintSettings', 'PrintTable', 'Project', 'PulseStats', + 'PutScrapText', 'pwd', 'Quit', 'RatioFromNumber', 'Redimension', 'Remez', + 'Remove', 'RemoveContour', 'RemoveFromGizmo', 'RemoveFromGraph', + 'RemoveFromLayout', 'RemoveFromTable', 'RemoveImage', 'RemoveLayoutObjects', + 'RemovePath', 'Rename', 'RenameDataFolder', 'RenamePath', 'RenamePICT', + 'RenameWindow', 'ReorderImages', 'ReorderTraces', 'ReplaceText', 'ReplaceWave', + 'Resample', 'ResumeUpdate', 'Reverse', 'Rotate', 'Save', 'SaveData', + 'SaveExperiment', 'SaveGizmoCopy', 'SaveGraphCopy', 'SaveNotebook', + 'SavePackagePreferences', 'SavePICT', 'SaveTableCopy', 'SetActiveSubwindow', + 'SetAxis', 'SetBackground', 'SetDashPattern', 'SetDataFolder', 'SetDimLabel', + 'SetDrawEnv', 'SetDrawLayer', 'SetFileFolderInfo', 'SetFormula', + 'SetIdlePeriod', 'SetIgorHook', 'SetIgorMenuMode', 'SetIgorOption', + 'SetMarquee', 'SetProcessSleep', 'SetRandomSeed', 'SetScale', 'SetVariable', + 'SetWaveLock', 'SetWaveTextEncoding', 'SetWindow', 'ShowIgorMenus', 'ShowInfo', + 'ShowTools', 'Silent', 'Sleep', 'Slider', 'Smooth', 'SmoothCustom', 'Sort', + 'SortColumns', 'SoundInRecord', 'SoundInSet', 'SoundInStartChart', + 'SoundInStatus', 'SoundInStopChart', 'SoundLoadWave', 'SoundSaveWave', + 'SphericalInterpolate', 'SphericalTriangulate', 'SplitString', 'SplitWave', + 'sprintf', 'SQLHighLevelOp', 'sscanf', 'Stack', 'StackWindows', + 'StatsAngularDistanceTest', 'StatsANOVA1Test', 'StatsANOVA2NRTest', + 'StatsANOVA2RMTest', 'StatsANOVA2Test', 'StatsChiTest', + 'StatsCircularCorrelationTest', 'StatsCircularMeans', 'StatsCircularMoments', + 'StatsCircularTwoSampleTest', 'StatsCochranTest', 'StatsContingencyTable', + 'StatsDIPTest', 'StatsDunnettTest', 'StatsFriedmanTest', 'StatsFTest', + 'StatsHodgesAjneTest', 'StatsJBTest', 'StatsKDE', 'StatsKendallTauTest', + 'StatsKSTest', 'StatsKWTest', 'StatsLinearCorrelationTest', + 'StatsLinearRegression', 'StatsMultiCorrelationTest', 'StatsNPMCTest', + 'StatsNPNominalSRTest', 'StatsQuantiles', 'StatsRankCorrelationTest', + 'StatsResample', 'StatsSample', 'StatsScheffeTest', 'StatsShapiroWilkTest', + 'StatsSignTest', 'StatsSRTest', 'StatsTTest', 'StatsTukeyTest', + 'StatsVariancesTest', 'StatsWatsonUSquaredTest', 'StatsWatsonWilliamsTest', + 'StatsWheelerWatsonTest', 'StatsWilcoxonRankTest', 'StatsWRCorrelationTest', + 'STFT', 'StructFill', 'StructGet', 'StructPut', 'SumDimension', 'SumSeries', + 'TabControl', 'Tag', 'TDMLoadData', 'TDMSaveData', 'TextBox', 'TextHistogram', + 'Text2Bezier', 'ThreadGroupPutDF', 'ThreadStart', 'TickWavesFromAxis', 'Tile', + 'TileWindows', 'TitleBox', 'ToCommandLine', 'ToolsGrid', 'Triangulate3d', + 'TUFXOP_AcquireLock', 'TUFXOP_Clear', 'TUFXOP_GetStorage', 'TUFXOP_Init', + 'TUFXOP_ReleaseLock', 'TUFXOP_RunningInMainThread', 'TUFXOP_Version', 'Unwrap', + 'UnzipFile', 'URLRequest', 'ValDisplay', 'VDTClosePort2', 'VDTGetPortList2', + 'VDTGetStatus2', 'VDTOpenPort2', 'VDTOperationsPort2', 'VDTReadBinaryWave2', + 'VDTReadBinary2', 'VDTReadHexWave2', 'VDTReadHex2', 'VDTReadWave2', 'VDTRead2', + 'VDTTerminalPort2', 'VDTWriteBinaryWave2', 'VDTWriteBinary2', + 'VDTWriteHexWave2', 'VDTWriteHex2', 'VDTWriteWave2', 'VDTWrite2', 'VDT2', + 'VISAControl', 'VISARead', 'VISAReadBinary', 'VISAReadBinaryWave', + 'VISAReadWave', 'VISAWrite', 'VISAWriteBinary', 'VISAWriteBinaryWave', + 'VISAWriteWave', 'WaveMeanStdv', 'WaveStats', 'WaveTracking', 'WaveTransform', + 'wfprintf', 'WignerTransform', 'WindowFunction', 'XLLoadWave' + ) + functions = ( + 'abs', 'acos', 'acosh', 'AddListItem', 'AiryA', 'AiryAD', 'AiryB', 'AiryBD', + 'alog', 'AnnotationInfo', 'AnnotationList', 'area', 'areaXY', 'asin', 'asinh', + 'atan', 'atanh', 'atan2', 'AxisInfo', 'AxisLabel', 'AxisList', + 'AxisValFromPixel', 'AxonTelegraphAGetDataNum', 'AxonTelegraphAGetDataString', + 'AxonTelegraphAGetDataStruct', 'AxonTelegraphGetDataNum', + 'AxonTelegraphGetDataString', 'AxonTelegraphGetDataStruct', + 'AxonTelegraphGetTimeoutMs', 'AxonTelegraphSetTimeoutMs', 'Base64Decode', + 'Base64Encode', 'Besseli', 'Besselj', 'Besselk', 'Bessely', 'beta', 'betai', + 'BinarySearch', 'BinarySearchInterp', 'binomial', 'binomialln', 'binomialNoise', + 'cabs', 'CaptureHistory', 'CaptureHistoryStart', 'ceil', 'centerOfMass', + 'centerOfMassXY', 'cequal', 'char2num', 'chebyshev', 'chebyshevU', 'CheckName', + 'ChildWindowList', 'CleanupName', 'cmplx', 'cmpstr', 'conj', 'ContourInfo', + 'ContourNameList', 'ContourNameToWaveRef', 'ContourZ', 'ControlNameList', + 'ConvertTextEncoding', 'cos', 'cosh', 'cosIntegral', 'cot', 'coth', + 'CountObjects', 'CountObjectsDFR', 'cpowi', 'CreateDataObjectName', + 'CreationDate', 'csc', 'csch', 'CsrInfo', 'CsrWave', 'CsrWaveRef', 'CsrXWave', + 'CsrXWaveRef', 'CTabList', 'DataFolderDir', 'DataFolderExists', + 'DataFolderList', 'DataFolderRefChanges', 'DataFolderRefsEqual', + 'DataFolderRefStatus', 'date', 'datetime', 'DateToJulian', 'date2secs', + 'Dawson', 'defined', 'deltax', 'digamma', 'dilogarithm', 'DimDelta', + 'DimOffset', 'DimSize', 'ei', 'ellipticE', 'ellipticK', 'enoise', 'equalWaves', + 'erf', 'erfc', 'erfcw', 'erfcx', 'exists', 'exp', 'expInt', 'expIntegralE1', + 'expNoise', 'factorial', 'Faddeeva', 'fakedata', 'faverage', 'faverageXY', + 'fDAQmx_AI_ChannelConfigs', 'fDAQmx_AI_GetReader', 'fDAQmx_AO_UpdateOutputs', + 'fDAQmx_ConnectTerminals', 'fDAQmx_CTR_Finished', 'fDAQmx_CTR_IsFinished', + 'fDAQmx_CTR_IsPulseFinished', 'fDAQmx_CTR_ReadCounter', + 'fDAQmx_CTR_ReadWithOptions', 'fDAQmx_CTR_SetPulseFrequency', + 'fDAQmx_CTR_Start', 'fDAQmx_DeviceNames', 'fDAQmx_DIO_Finished', + 'fDAQmx_DIO_PortWidth', 'fDAQmx_DIO_Read', 'fDAQmx_DIO_Write', + 'fDAQmx_DisconnectTerminals', 'fDAQmx_ErrorString', 'fDAQmx_ExternalCalDate', + 'fDAQmx_NumAnalogInputs', 'fDAQmx_NumAnalogOutputs', 'fDAQmx_NumCounters', + 'fDAQmx_NumDIOPorts', 'fDAQmx_ReadChan', 'fDAQmx_ReadNamedChan', + 'fDAQmx_ResetDevice', 'fDAQmx_ScanGetAvailable', 'fDAQmx_ScanGetNextIndex', + 'fDAQmx_ScanStart', 'fDAQmx_ScanStop', 'fDAQmx_ScanWait', + 'fDAQmx_ScanWaitWithTimeout', 'fDAQmx_SelfCalDate', 'fDAQmx_SelfCalibration', + 'fDAQmx_WaveformStart', 'fDAQmx_WaveformStop', 'fDAQmx_WF_IsFinished', + 'fDAQmx_WF_WaitUntilFinished', 'fDAQmx_WriteChan', 'FetchURL', 'FindDimLabel', + 'FindListItem', 'floor', 'FontList', 'FontSizeHeight', 'FontSizeStringWidth', + 'FresnelCos', 'FresnelSin', 'FuncRefInfo', 'FunctionInfo', 'FunctionList', + 'FunctionPath', 'gamma', 'gammaEuler', 'gammaInc', 'gammaNoise', 'gammln', + 'gammp', 'gammq', 'Gauss', 'Gauss1D', 'Gauss2D', 'gcd', 'GeometricMean', + 'GetBrowserLine', 'GetBrowserSelection', 'GetDataFolder', 'GetDataFolderDFR', + 'GetDefaultFont', 'GetDefaultFontSize', 'GetDefaultFontStyle', 'GetDimLabel', + 'GetEnvironmentVariable', 'GetErrMessage', 'GetFormula', + 'GetIndependentModuleName', 'GetIndexedObjName', 'GetIndexedObjNameDFR', + 'GetKeyState', 'GetRTErrMessage', 'GetRTError', 'GetRTLocation', 'GetRTLocInfo', + 'GetRTStackInfo', 'GetScrapText', 'GetUserData', 'GetWavesDataFolder', + 'GetWavesDataFolderDFR', 'GetWindowBrowserSelection', 'GISGetAllFileFormats', + 'GISSRefsAreEqual', 'GizmoInfo', 'GizmoScale', 'gnoise', 'GrepList', + 'GrepString', 'GuideInfo', 'GuideNameList', 'Hash', 'hcsr', 'HDF5AttributeInfo', + 'HDF5DatasetInfo', 'HDF5LibraryInfo', 'HDF5LinkInfo', 'HDF5TypeInfo', 'hermite', + 'hermiteGauss', 'HyperGNoise', 'HyperGPFQ', 'HyperG0F1', 'HyperG1F1', + 'HyperG2F1', 'i', 'IgorInfo', 'IgorVersion', 'imag', 'ImageInfo', + 'ImageNameList', 'ImageNameToWaveRef', 'IndependentModuleList', 'IndexedDir', + 'IndexedFile', 'IndexToScale', 'Inf', 'Integrate1D', 'interp', 'Interp2D', + 'Interp3D', 'inverseERF', 'inverseERFC', 'ItemsInList', 'JacobiCn', 'JacobiSn', + 'JulianToDate', 'Laguerre', 'LaguerreA', 'LaguerreGauss', 'LambertW', + 'LayoutInfo', 'leftx', 'LegendreA', 'limit', 'ListMatch', 'ListToTextWave', + 'ListToWaveRefWave', 'ln', 'log', 'logNormalNoise', 'lorentzianNoise', + 'LowerStr', 'MacroInfo', 'MacroList', 'MacroPath', 'magsqr', 'MandelbrotPoint', + 'MarcumQ', 'MatrixCondition', 'MatrixDet', 'MatrixDot', 'MatrixRank', + 'MatrixTrace', 'max', 'MCC_AutoBridgeBal', 'MCC_AutoFastComp', + 'MCC_AutoPipetteOffset', 'MCC_AutoSlowComp', 'MCC_AutoWholeCellComp', + 'MCC_GetBridgeBalEnable', 'MCC_GetBridgeBalResist', 'MCC_GetFastCompCap', + 'MCC_GetFastCompTau', 'MCC_GetHolding', 'MCC_GetHoldingEnable', 'MCC_GetMode', + 'MCC_GetNeutralizationCap', 'MCC_GetNeutralizationEnable', + 'MCC_GetOscKillerEnable', 'MCC_GetPipetteOffset', 'MCC_GetPrimarySignalGain', + 'MCC_GetPrimarySignalHPF', 'MCC_GetPrimarySignalLPF', 'MCC_GetRsCompBandwidth', + 'MCC_GetRsCompCorrection', 'MCC_GetRsCompEnable', 'MCC_GetRsCompPrediction', + 'MCC_GetSecondarySignalGain', 'MCC_GetSecondarySignalLPF', 'MCC_GetSlowCompCap', + 'MCC_GetSlowCompTau', 'MCC_GetSlowCompTauX20Enable', + 'MCC_GetSlowCurrentInjEnable', 'MCC_GetSlowCurrentInjLevel', + 'MCC_GetSlowCurrentInjSetlTime', 'MCC_GetWholeCellCompCap', + 'MCC_GetWholeCellCompEnable', 'MCC_GetWholeCellCompResist', + 'MCC_SelectMultiClamp700B', 'MCC_SetBridgeBalEnable', 'MCC_SetBridgeBalResist', + 'MCC_SetFastCompCap', 'MCC_SetFastCompTau', 'MCC_SetHolding', + 'MCC_SetHoldingEnable', 'MCC_SetMode', 'MCC_SetNeutralizationCap', + 'MCC_SetNeutralizationEnable', 'MCC_SetOscKillerEnable', 'MCC_SetPipetteOffset', + 'MCC_SetPrimarySignalGain', 'MCC_SetPrimarySignalHPF', 'MCC_SetPrimarySignalLPF', + 'MCC_SetRsCompBandwidth', 'MCC_SetRsCompCorrection', 'MCC_SetRsCompEnable', + 'MCC_SetRsCompPrediction', 'MCC_SetSecondarySignalGain', + 'MCC_SetSecondarySignalLPF', 'MCC_SetSlowCompCap', 'MCC_SetSlowCompTau', + 'MCC_SetSlowCompTauX20Enable', 'MCC_SetSlowCurrentInjEnable', + 'MCC_SetSlowCurrentInjLevel', 'MCC_SetSlowCurrentInjSetlTime', + 'MCC_SetTimeoutMs', 'MCC_SetWholeCellCompCap', 'MCC_SetWholeCellCompEnable', + 'MCC_SetWholeCellCompResist', 'mean', 'median', 'min', 'mod', 'ModDate', + 'MPFXEMGPeak', 'MPFXExpConvExpPeak', 'MPFXGaussPeak', 'MPFXLorentzianPeak', + 'MPFXVoigtPeak', 'NameOfWave', 'NaN', 'NewFreeDataFolder', 'NewFreeWave', 'norm', + 'NormalizeUnicode', 'note', 'NumberByKey', 'numpnts', 'numtype', + 'NumVarOrDefault', 'num2char', 'num2istr', 'num2str', 'NVAR_Exists', + 'OperationList', 'PadString', 'PanelResolution', 'ParamIsDefault', + 'ParseFilePath', 'PathList', 'pcsr', 'Pi', 'PICTInfo', 'PICTList', + 'PixelFromAxisVal', 'pnt2x', 'poissonNoise', 'poly', 'PolygonArea', 'poly2D', + 'PossiblyQuoteName', 'ProcedureText', 'ProcedureVersion', 'p2rect', 'qcsr', + 'real', 'RemoveByKey', 'RemoveEnding', 'RemoveFromList', 'RemoveListItem', + 'ReplaceNumberByKey', 'ReplaceString', 'ReplaceStringByKey', 'ReplicateString', + 'rightx', 'round', 'r2polar', 'sawtooth', 'scaleToIndex', 'ScreenResolution', + 'sec', 'sech', 'Secs2Date', 'Secs2Time', 'SelectNumber', 'SelectString', + 'SetEnvironmentVariable', 'sign', 'sin', 'sinc', 'sinh', 'sinIntegral', + 'SortList', 'SpecialCharacterInfo', 'SpecialCharacterList', 'SpecialDirPath', + 'SphericalBessJ', 'SphericalBessJD', 'SphericalBessY', 'SphericalBessYD', + 'SphericalHarmonics', 'SQLAllocHandle', 'SQLAllocStmt', + 'SQLBinaryWavesToTextWave', 'SQLBindCol', 'SQLBindParameter', 'SQLBrowseConnect', + 'SQLBulkOperations', 'SQLCancel', 'SQLCloseCursor', 'SQLColAttributeNum', + 'SQLColAttributeStr', 'SQLColumnPrivileges', 'SQLColumns', 'SQLConnect', + 'SQLDataSources', 'SQLDescribeCol', 'SQLDescribeParam', 'SQLDisconnect', + 'SQLDriverConnect', 'SQLDrivers', 'SQLEndTran', 'SQLError', 'SQLExecDirect', + 'SQLExecute', 'SQLFetch', 'SQLFetchScroll', 'SQLForeignKeys', 'SQLFreeConnect', + 'SQLFreeEnv', 'SQLFreeHandle', 'SQLFreeStmt', 'SQLGetConnectAttrNum', + 'SQLGetConnectAttrStr', 'SQLGetCursorName', 'SQLGetDataNum', 'SQLGetDataStr', + 'SQLGetDescFieldNum', 'SQLGetDescFieldStr', 'SQLGetDescRec', + 'SQLGetDiagFieldNum', 'SQLGetDiagFieldStr', 'SQLGetDiagRec', 'SQLGetEnvAttrNum', + 'SQLGetEnvAttrStr', 'SQLGetFunctions', 'SQLGetInfoNum', 'SQLGetInfoStr', + 'SQLGetStmtAttrNum', 'SQLGetStmtAttrStr', 'SQLGetTypeInfo', 'SQLMoreResults', + 'SQLNativeSql', 'SQLNumParams', 'SQLNumResultCols', 'SQLNumResultRowsIfKnown', + 'SQLNumRowsFetched', 'SQLParamData', 'SQLPrepare', 'SQLPrimaryKeys', + 'SQLProcedureColumns', 'SQLProcedures', 'SQLPutData', 'SQLReinitialize', + 'SQLRowCount', 'SQLSetConnectAttrNum', 'SQLSetConnectAttrStr', + 'SQLSetCursorName', 'SQLSetDescFieldNum', 'SQLSetDescFieldStr', 'SQLSetDescRec', + 'SQLSetEnvAttrNum', 'SQLSetEnvAttrStr', 'SQLSetPos', 'SQLSetStmtAttrNum', + 'SQLSetStmtAttrStr', 'SQLSpecialColumns', 'SQLStatistics', 'SQLTablePrivileges', + 'SQLTables', 'SQLTextWaveToBinaryWaves', 'SQLTextWaveTo2DBinaryWave', + 'SQLUpdateBoundValues', 'SQLXOPCheckState', 'SQL2DBinaryWaveToTextWave', 'sqrt', + 'StartMSTimer', 'StatsBetaCDF', 'StatsBetaPDF', 'StatsBinomialCDF', + 'StatsBinomialPDF', 'StatsCauchyCDF', 'StatsCauchyPDF', 'StatsChiCDF', + 'StatsChiPDF', 'StatsCMSSDCDF', 'StatsCorrelation', 'StatsDExpCDF', + 'StatsDExpPDF', 'StatsErlangCDF', 'StatsErlangPDF', 'StatsErrorPDF', + 'StatsEValueCDF', 'StatsEValuePDF', 'StatsExpCDF', 'StatsExpPDF', 'StatsFCDF', + 'StatsFPDF', 'StatsFriedmanCDF', 'StatsGammaCDF', 'StatsGammaPDF', + 'StatsGeometricCDF', 'StatsGeometricPDF', 'StatsGEVCDF', 'StatsGEVPDF', + 'StatsHyperGCDF', 'StatsHyperGPDF', 'StatsInvBetaCDF', 'StatsInvBinomialCDF', + 'StatsInvCauchyCDF', 'StatsInvChiCDF', 'StatsInvCMSSDCDF', 'StatsInvDExpCDF', + 'StatsInvEValueCDF', 'StatsInvExpCDF', 'StatsInvFCDF', 'StatsInvFriedmanCDF', + 'StatsInvGammaCDF', 'StatsInvGeometricCDF', 'StatsInvKuiperCDF', + 'StatsInvLogisticCDF', 'StatsInvLogNormalCDF', 'StatsInvMaxwellCDF', + 'StatsInvMooreCDF', 'StatsInvNBinomialCDF', 'StatsInvNCChiCDF', 'StatsInvNCFCDF', + 'StatsInvNormalCDF', 'StatsInvParetoCDF', 'StatsInvPoissonCDF', + 'StatsInvPowerCDF', 'StatsInvQCDF', 'StatsInvQpCDF', 'StatsInvRayleighCDF', + 'StatsInvRectangularCDF', 'StatsInvSpearmanCDF', 'StatsInvStudentCDF', + 'StatsInvTopDownCDF', 'StatsInvTriangularCDF', 'StatsInvUsquaredCDF', + 'StatsInvVonMisesCDF', 'StatsInvWeibullCDF', 'StatsKuiperCDF', + 'StatsLogisticCDF', 'StatsLogisticPDF', 'StatsLogNormalCDF', 'StatsLogNormalPDF', + 'StatsMaxwellCDF', 'StatsMaxwellPDF', 'StatsMedian', 'StatsMooreCDF', + 'StatsNBinomialCDF', 'StatsNBinomialPDF', 'StatsNCChiCDF', 'StatsNCChiPDF', + 'StatsNCFCDF', 'StatsNCFPDF', 'StatsNCTCDF', 'StatsNCTPDF', 'StatsNormalCDF', + 'StatsNormalPDF', 'StatsParetoCDF', 'StatsParetoPDF', 'StatsPermute', + 'StatsPoissonCDF', 'StatsPoissonPDF', 'StatsPowerCDF', 'StatsPowerNoise', + 'StatsPowerPDF', 'StatsQCDF', 'StatsQpCDF', 'StatsRayleighCDF', + 'StatsRayleighPDF', 'StatsRectangularCDF', 'StatsRectangularPDF', 'StatsRunsCDF', + 'StatsSpearmanRhoCDF', 'StatsStudentCDF', 'StatsStudentPDF', 'StatsTopDownCDF', + 'StatsTriangularCDF', 'StatsTriangularPDF', 'StatsTrimmedMean', + 'StatsUSquaredCDF', 'StatsVonMisesCDF', 'StatsVonMisesNoise', 'StatsVonMisesPDF', + 'StatsWaldCDF', 'StatsWaldPDF', 'StatsWeibullCDF', 'StatsWeibullPDF', + 'StopMSTimer', 'StringByKey', 'stringCRC', 'StringFromList', 'StringList', + 'stringmatch', 'StringToUnsignedByteWave', 'strlen', 'strsearch', + 'StrVarOrDefault', 'str2num', 'StudentA', 'StudentT', 'sum', 'SVAR_Exists', + 'TableInfo', 'TagVal', 'TagWaveRef', 'tan', 'tanh', 'TDMAddChannel', + 'TDMAddGroup', 'TDMAppendDataValues', 'TDMAppendDataValuesTime', + 'TDMChannelPropertyExists', 'TDMCloseChannel', 'TDMCloseFile', 'TDMCloseGroup', + 'TDMCreateChannelProperty', 'TDMCreateFile', 'TDMCreateFileProperty', + 'TDMCreateGroupProperty', 'TDMFilePropertyExists', 'TDMGetChannelPropertyNames', + 'TDMGetChannelPropertyNum', 'TDMGetChannelPropertyStr', + 'TDMGetChannelPropertyTime', 'TDMGetChannelPropertyType', 'TDMGetChannels', + 'TDMGetChannelStringPropertyLen', 'TDMGetDataType', 'TDMGetDataValues', + 'TDMGetDataValuesTime', 'TDMGetFilePropertyNames', 'TDMGetFilePropertyNum', + 'TDMGetFilePropertyStr', 'TDMGetFilePropertyTime', 'TDMGetFilePropertyType', + 'TDMGetFileStringPropertyLen', 'TDMGetGroupPropertyNames', + 'TDMGetGroupPropertyNum', 'TDMGetGroupPropertyStr', 'TDMGetGroupPropertyTime', + 'TDMGetGroupPropertyType', 'TDMGetGroups', 'TDMGetGroupStringPropertyLen', + 'TDMGetLibraryErrorDescription', 'TDMGetNumChannelProperties', + 'TDMGetNumChannels', 'TDMGetNumDataValues', 'TDMGetNumFileProperties', + 'TDMGetNumGroupProperties', 'TDMGetNumGroups', 'TDMGroupPropertyExists', + 'TDMOpenFile', 'TDMOpenFileEx', 'TDMRemoveChannel', 'TDMRemoveGroup', + 'TDMReplaceDataValues', 'TDMReplaceDataValuesTime', 'TDMSaveFile', + 'TDMSetChannelPropertyNum', 'TDMSetChannelPropertyStr', + 'TDMSetChannelPropertyTime', 'TDMSetDataValues', 'TDMSetDataValuesTime', + 'TDMSetFilePropertyNum', 'TDMSetFilePropertyStr', 'TDMSetFilePropertyTime', + 'TDMSetGroupPropertyNum', 'TDMSetGroupPropertyStr', 'TDMSetGroupPropertyTime', + 'TextEncodingCode', 'TextEncodingName', 'TextFile', 'ThreadGroupCreate', + 'ThreadGroupGetDF', 'ThreadGroupGetDFR', 'ThreadGroupRelease', 'ThreadGroupWait', + 'ThreadProcessorCount', 'ThreadReturnValue', 'ticks', 'time', 'TraceFromPixel', + 'TraceInfo', 'TraceNameList', 'TraceNameToWaveRef', 'TrimString', 'trunc', + 'UniqueName', 'UnPadString', 'UnsetEnvironmentVariable', 'UpperStr', 'URLDecode', + 'URLEncode', 'VariableList', 'Variance', 'vcsr', 'viAssertIntrSignal', + 'viAssertTrigger', 'viAssertUtilSignal', 'viClear', 'viClose', 'viDisableEvent', + 'viDiscardEvents', 'viEnableEvent', 'viFindNext', 'viFindRsrc', 'viGetAttribute', + 'viGetAttributeString', 'viGpibCommand', 'viGpibControlATN', 'viGpibControlREN', + 'viGpibPassControl', 'viGpibSendIFC', 'viIn8', 'viIn16', 'viIn32', 'viLock', + 'viMapAddress', 'viMapTrigger', 'viMemAlloc', 'viMemFree', 'viMoveIn8', + 'viMoveIn16', 'viMoveIn32', 'viMoveOut8', 'viMoveOut16', 'viMoveOut32', 'viOpen', + 'viOpenDefaultRM', 'viOut8', 'viOut16', 'viOut32', 'viPeek8', 'viPeek16', + 'viPeek32', 'viPoke8', 'viPoke16', 'viPoke32', 'viRead', 'viReadSTB', + 'viSetAttribute', 'viSetAttributeString', 'viStatusDesc', 'viTerminate', + 'viUnlock', 'viUnmapAddress', 'viUnmapTrigger', 'viUsbControlIn', + 'viUsbControlOut', 'viVxiCommandQuery', 'viWaitOnEvent', 'viWrite', 'VoigtFunc', + 'VoigtPeak', 'WaveCRC', 'WaveDataToString', 'WaveDims', 'WaveExists', 'WaveHash', + 'WaveInfo', 'WaveList', 'WaveMax', 'WaveMin', 'WaveMinAndMax', 'WaveModCount', + 'WaveName', 'WaveRefIndexed', 'WaveRefIndexedDFR', 'WaveRefsEqual', + 'WaveRefWaveToList', 'WaveTextEncoding', 'WaveType', 'WaveUnits', + 'WhichListItem', 'WinList', 'WinName', 'WinRecreation', 'WinType', 'wnoise', + 'xcsr', 'XWaveName', 'XWaveRefFromTrace', 'x2pnt', 'zcsr', 'ZernikeR', + 'zeromq_client_connect', 'zeromq_client_recv', 'zeromq_client_send', + 'zeromq_handler_start', 'zeromq_handler_stop', 'zeromq_pub_bind', + 'zeromq_pub_send', 'zeromq_server_bind', 'zeromq_server_recv', + 'zeromq_server_send', 'zeromq_set', 'zeromq_set_logging_template', 'zeromq_stop', + 'zeromq_sub_add_filter', 'zeromq_sub_connect', 'zeromq_sub_recv', + 'zeromq_sub_remove_filter', 'zeromq_test_callfunction', + 'zeromq_test_serializeWave', 'zeta' + ) + + tokens = { + 'root': [ + (r'//.*$', Comment.Single), + (r'"([^"\\]|\\.)*"', String), + # Flow Control. + (words(flowControl, prefix=r'\b', suffix=r'\b'), Keyword), + # Types. + (words(types, prefix=r'\b', suffix=r'\b'), Keyword.Type), + # Keywords. + (words(keywords, prefix=r'\b', suffix=r'\b'), Keyword.Reserved), + # Built-in operations. + (words(operations, prefix=r'\b', suffix=r'\b'), Name.Class), + # Built-in functions. + (words(functions, prefix=r'\b', suffix=r'\b'), Name.Function), + # Compiler directives. + (r'^#(include|pragma|define|undef|ifdef|ifndef|if|elif|else|endif)', + Name.Decorator), + (r'\s+', Whitespace), + (r'[^a-z"/]+$', Text), + (r'.', Text), + ], + } diff --git a/pygments/lexers/inferno.py b/pygments/lexers/inferno.py new file mode 100644 index 0000000000000000000000000000000000000000..b5caf55a002d4b1131731c54c1c44bdd64f1b471 --- /dev/null +++ b/pygments/lexers/inferno.py @@ -0,0 +1,95 @@ +""" + pygments.lexers.inferno + ~~~~~~~~~~~~~~~~~~~~~~~ + + Lexers for Inferno os and all the related stuff. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +import re + +from pygments.lexer import RegexLexer, include, bygroups, default +from pygments.token import Punctuation, Comment, Operator, Keyword, \ + Name, String, Number, Whitespace + +__all__ = ['LimboLexer'] + + +class LimboLexer(RegexLexer): + """ + Lexer for Limbo programming language + + TODO: + - maybe implement better var declaration highlighting + - some simple syntax error highlighting + """ + name = 'Limbo' + url = 'http://www.vitanuova.com/inferno/limbo.html' + aliases = ['limbo'] + filenames = ['*.b'] + mimetypes = ['text/limbo'] + version_added = '2.0' + + tokens = { + 'whitespace': [ + (r'^(\s*)([a-zA-Z_]\w*:)(\s*\n)', + bygroups(Whitespace, Name.Label, Whitespace)), + (r'\n', Whitespace), + (r'\s+', Whitespace), + (r'#(\n|(.|\n)*?[^\\]\n)', Comment.Single), + ], + 'string': [ + (r'"', String, '#pop'), + (r'\\([\\abfnrtv"\']|x[a-fA-F0-9]{2,4}|' + r'u[a-fA-F0-9]{4}|U[a-fA-F0-9]{8}|[0-7]{1,3})', String.Escape), + (r'[^\\"\n]+', String), # all other characters + (r'\\', String), # stray backslash + ], + 'statements': [ + (r'"', String, 'string'), + (r"'(\\.|\\[0-7]{1,3}|\\x[a-fA-F0-9]{1,2}|[^\\\'\n])'", String.Char), + (r'(\d+\.\d*|\.\d+|\d+)[eE][+-]?\d+', Number.Float), + (r'(\d+\.\d*|\.\d+|\d+[fF])', Number.Float), + (r'16r[0-9a-fA-F]+', Number.Hex), + (r'8r[0-7]+', Number.Oct), + (r'((([1-3]\d)|([2-9]))r)?(\d+)', Number.Integer), + (r'[()\[\],.]', Punctuation), + (r'[~!%^&*+=|?:<>/-]|(->)|(<-)|(=>)|(::)', Operator), + (r'(alt|break|case|continue|cyclic|do|else|exit' + r'for|hd|if|implement|import|include|len|load|or' + r'pick|return|spawn|tagof|tl|to|while)\b', Keyword), + (r'(byte|int|big|real|string|array|chan|list|adt' + r'|fn|ref|of|module|self|type)\b', Keyword.Type), + (r'(con|iota|nil)\b', Keyword.Constant), + (r'[a-zA-Z_]\w*', Name), + ], + 'statement' : [ + include('whitespace'), + include('statements'), + ('[{}]', Punctuation), + (';', Punctuation, '#pop'), + ], + 'root': [ + include('whitespace'), + default('statement'), + ], + } + + def analyse_text(text): + # Any limbo module implements something + if re.search(r'^implement \w+;', text, re.MULTILINE): + return 0.7 + +# TODO: +# - Make lexers for: +# - asm sources +# - man pages +# - mkfiles +# - module definitions +# - namespace definitions +# - shell scripts +# - maybe keyfiles and fonts +# they all seem to be quite similar to their equivalents +# from unix world, so there should not be a lot of problems diff --git a/pygments/lexers/installers.py b/pygments/lexers/installers.py new file mode 100644 index 0000000000000000000000000000000000000000..3f12a25b2ce4d64b9e9af1950b35cbd5cbcfee47 --- /dev/null +++ b/pygments/lexers/installers.py @@ -0,0 +1,352 @@ +""" + pygments.lexers.installers + ~~~~~~~~~~~~~~~~~~~~~~~~~~ + + Lexers for installer/packager DSLs and formats. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +import re + +from pygments.lexer import RegexLexer, include, bygroups, using, this, default +from pygments.token import Text, Comment, Operator, Keyword, Name, String, \ + Punctuation, Generic, Number, Whitespace + +__all__ = ['NSISLexer', 'RPMSpecLexer', + 'DebianSourcesLexer', 'SourcesListLexer', + 'DebianControlLexer'] + + +class NSISLexer(RegexLexer): + """ + For NSIS scripts. + """ + name = 'NSIS' + url = 'http://nsis.sourceforge.net/' + aliases = ['nsis', 'nsi', 'nsh'] + filenames = ['*.nsi', '*.nsh'] + mimetypes = ['text/x-nsis'] + version_added = '1.6' + + flags = re.IGNORECASE + + tokens = { + 'root': [ + (r'([;#].*)(\n)', bygroups(Comment, Whitespace)), + (r"'.*?'", String.Single), + (r'"', String.Double, 'str_double'), + (r'`', String.Backtick, 'str_backtick'), + include('macro'), + include('interpol'), + include('basic'), + (r'\$\{[a-z_|][\w|]*\}', Keyword.Pseudo), + (r'/[a-z_]\w*', Name.Attribute), + (r'\s+', Whitespace), + (r'[\w.]+', Text), + ], + 'basic': [ + (r'(\n)(Function)(\s+)([._a-z][.\w]*)\b', + bygroups(Whitespace, Keyword, Whitespace, Name.Function)), + (r'\b([_a-z]\w*)(::)([a-z][a-z0-9]*)\b', + bygroups(Keyword.Namespace, Punctuation, Name.Function)), + (r'\b([_a-z]\w*)(:)', bygroups(Name.Label, Punctuation)), + (r'(\b[ULS]|\B)([!<>=]?=|\<\>?|\>)\B', Operator), + (r'[|+-]', Operator), + (r'\\', Punctuation), + (r'\b(Abort|Add(?:BrandingImage|Size)|' + r'Allow(?:RootDirInstall|SkipFiles)|AutoCloseWindow|' + r'BG(?:Font|Gradient)|BrandingText|BringToFront|Call(?:InstDLL)?|' + r'(?:Sub)?Caption|ChangeUI|CheckBitmap|ClearErrors|CompletedText|' + r'ComponentText|CopyFiles|CRCCheck|' + r'Create(?:Directory|Font|Shortcut)|Delete(?:INI(?:Sec|Str)|' + r'Reg(?:Key|Value))?|DetailPrint|DetailsButtonText|' + r'Dir(?:Show|Text|Var|Verify)|(?:Disabled|Enabled)Bitmap|' + r'EnableWindow|EnumReg(?:Key|Value)|Exch|Exec(?:Shell|Wait)?|' + r'ExpandEnvStrings|File(?:BufSize|Close|ErrorText|Open|' + r'Read(?:Byte)?|Seek|Write(?:Byte)?)?|' + r'Find(?:Close|First|Next|Window)|FlushINI|Function(?:End)?|' + r'Get(?:CurInstType|CurrentAddress|DlgItem|DLLVersion(?:Local)?|' + r'ErrorLevel|FileTime(?:Local)?|FullPathName|FunctionAddress|' + r'InstDirError|LabelAddress|TempFileName)|' + r'Goto|HideWindow|Icon|' + r'If(?:Abort|Errors|FileExists|RebootFlag|Silent)|' + r'InitPluginsDir|Install(?:ButtonText|Colors|Dir(?:RegKey)?)|' + r'Inst(?:ProgressFlags|Type(?:[GS]etText)?)|Int(?:CmpU?|Fmt|Op)|' + r'IsWindow|LangString(?:UP)?|' + r'License(?:BkColor|Data|ForceSelection|LangString|Text)|' + r'LoadLanguageFile|LockWindow|Log(?:Set|Text)|MessageBox|' + r'MiscButtonText|Name|Nop|OutFile|(?:Uninst)?Page(?:Ex(?:End)?)?|' + r'PluginDir|Pop|Push|Quit|Read(?:(?:Env|INI|Reg)Str|RegDWORD)|' + r'Reboot|(?:Un)?RegDLL|Rename|RequestExecutionLevel|ReserveFile|' + r'Return|RMDir|SearchPath|Section(?:Divider|End|' + r'(?:(?:Get|Set)(?:Flags|InstTypes|Size|Text))|Group(?:End)?|In)?|' + r'SendMessage|Set(?:AutoClose|BrandingImage|Compress(?:ionLevel|' + r'or(?:DictSize)?)?|CtlColors|CurInstType|DatablockOptimize|' + r'DateSave|Details(?:Print|View)|Error(?:s|Level)|FileAttributes|' + r'Font|OutPath|Overwrite|PluginUnload|RebootFlag|ShellVarContext|' + r'Silent|StaticBkColor)|' + r'Show(?:(?:I|Uni)nstDetails|Window)|Silent(?:Un)?Install|Sleep|' + r'SpaceTexts|Str(?:CmpS?|Cpy|Len)|SubSection(?:End)?|' + r'Uninstall(?:ButtonText|(?:Sub)?Caption|EXEName|Icon|Text)|' + r'UninstPage|Var|VI(?:AddVersionKey|ProductVersion)|WindowIcon|' + r'Write(?:INIStr|Reg(:?Bin|DWORD|(?:Expand)?Str)|Uninstaller)|' + r'XPStyle)\b', Keyword), + (r'\b(CUR|END|(?:FILE_ATTRIBUTE_)?' + r'(?:ARCHIVE|HIDDEN|NORMAL|OFFLINE|READONLY|SYSTEM|TEMPORARY)|' + r'HK(CC|CR|CU|DD|LM|PD|U)|' + r'HKEY_(?:CLASSES_ROOT|CURRENT_(?:CONFIG|USER)|DYN_DATA|' + r'LOCAL_MACHINE|PERFORMANCE_DATA|USERS)|' + r'ID(?:ABORT|CANCEL|IGNORE|NO|OK|RETRY|YES)|' + r'MB_(?:ABORTRETRYIGNORE|DEFBUTTON[1-4]|' + r'ICON(?:EXCLAMATION|INFORMATION|QUESTION|STOP)|' + r'OK(?:CANCEL)?|RETRYCANCEL|RIGHT|SETFOREGROUND|TOPMOST|USERICON|' + r'YESNO(?:CANCEL)?)|SET|SHCTX|' + r'SW_(?:HIDE|SHOW(?:MAXIMIZED|MINIMIZED|NORMAL))|' + r'admin|all|auto|both|bottom|bzip2|checkbox|colored|current|false|' + r'force|hide|highest|if(?:diff|newer)|lastused|leave|left|' + r'listonly|lzma|nevershow|none|normal|off|on|pop|push|' + r'radiobuttons|right|show|silent|silentlog|smooth|textonly|top|' + r'true|try|user|zlib)\b', Name.Constant), + ], + 'macro': [ + (r'\!(addincludedir(?:dir)?|addplugindir|appendfile|cd|define|' + r'delfilefile|echo(?:message)?|else|endif|error|execute|' + r'if(?:macro)?n?(?:def)?|include|insertmacro|macro(?:end)?|packhdr|' + r'search(?:parse|replace)|system|tempfilesymbol|undef|verbose|' + r'warning)\b', Comment.Preproc), + ], + 'interpol': [ + (r'\$(R?[0-9])', Name.Builtin.Pseudo), # registers + (r'\$(ADMINTOOLS|APPDATA|CDBURN_AREA|COOKIES|COMMONFILES(?:32|64)|' + r'DESKTOP|DOCUMENTS|EXE(?:DIR|FILE|PATH)|FAVORITES|FONTS|HISTORY|' + r'HWNDPARENT|INTERNET_CACHE|LOCALAPPDATA|MUSIC|NETHOOD|PICTURES|' + r'PLUGINSDIR|PRINTHOOD|PROFILE|PROGRAMFILES(?:32|64)|QUICKLAUNCH|' + r'RECENT|RESOURCES(?:_LOCALIZED)?|SENDTO|SM(?:PROGRAMS|STARTUP)|' + r'STARTMENU|SYSDIR|TEMP(?:LATES)?|VIDEOS|WINDIR|\{NSISDIR\})', + Name.Builtin), + (r'\$(CMDLINE|INSTDIR|OUTDIR|LANGUAGE)', Name.Variable.Global), + (r'\$[a-z_]\w*', Name.Variable), + ], + 'str_double': [ + (r'"', String.Double, '#pop'), + (r'\$(\\[nrt"]|\$)', String.Escape), + include('interpol'), + (r'[^"]+', String.Double), + ], + 'str_backtick': [ + (r'`', String.Double, '#pop'), + (r'\$(\\[nrt"]|\$)', String.Escape), + include('interpol'), + (r'[^`]+', String.Double), + ], + } + + +class RPMSpecLexer(RegexLexer): + """ + For RPM ``.spec`` files. + """ + + name = 'RPMSpec' + aliases = ['spec'] + filenames = ['*.spec'] + mimetypes = ['text/x-rpm-spec'] + url = 'https://rpm-software-management.github.io/rpm/manual/spec.html' + version_added = '1.6' + + _directives = ('(?:package|prep|build|install|clean|check|pre[a-z]*|' + 'post[a-z]*|trigger[a-z]*|files)') + + tokens = { + 'root': [ + (r'#.*$', Comment), + include('basic'), + ], + 'description': [ + (r'^(%' + _directives + ')(.*)$', + bygroups(Name.Decorator, Text), '#pop'), + (r'\s+', Whitespace), + (r'.', Text), + ], + 'changelog': [ + (r'\*.*$', Generic.Subheading), + (r'^(%' + _directives + ')(.*)$', + bygroups(Name.Decorator, Text), '#pop'), + (r'\s+', Whitespace), + (r'.', Text), + ], + 'string': [ + (r'"', String.Double, '#pop'), + (r'\\([\\abfnrtv"\']|x[a-fA-F0-9]{2,4}|[0-7]{1,3})', String.Escape), + include('interpol'), + (r'.', String.Double), + ], + 'basic': [ + include('macro'), + (r'(?i)^(Name|Version|Release|Epoch|Summary|Group|License|Packager|' + r'Vendor|Icon|URL|Distribution|Prefix|Patch[0-9]*|Source[0-9]*|' + r'Requires\(?[a-z]*\)?|[a-z]+Req|Obsoletes|Suggests|Provides|Conflicts|' + r'Build[a-z]+|[a-z]+Arch|Auto[a-z]+)(:)(.*)$', + bygroups(Generic.Heading, Punctuation, using(this))), + (r'^%description', Name.Decorator, 'description'), + (r'^%changelog', Name.Decorator, 'changelog'), + (r'^(%' + _directives + ')(.*)$', bygroups(Name.Decorator, Text)), + (r'%(attr|defattr|dir|doc(?:dir)?|setup|config(?:ure)?|' + r'make(?:install)|ghost|patch[0-9]+|find_lang|exclude|verify)', + Keyword), + include('interpol'), + (r"'.*?'", String.Single), + (r'"', String.Double, 'string'), + (r'\s+', Whitespace), + (r'.', Text), + ], + 'macro': [ + (r'%define.*$', Comment.Preproc), + (r'%\{\!\?.*%define.*\}', Comment.Preproc), + (r'(%(?:if(?:n?arch)?|else(?:if)?|endif))(.*)$', + bygroups(Comment.Preproc, Text)), + ], + 'interpol': [ + (r'%\{?__[a-z_]+\}?', Name.Function), + (r'%\{?_([a-z_]+dir|[a-z_]+path|prefix)\}?', Keyword.Pseudo), + (r'%\{\?\w+\}', Name.Variable), + (r'\$\{?RPM_[A-Z0-9_]+\}?', Name.Variable.Global), + (r'%\{[a-zA-Z]\w+\}', Keyword.Constant), + ] + } + + +class DebianSourcesLexer(RegexLexer): + """ + Lexer that highlights debian.sources files. + """ + + name = 'Debian Sources file' + aliases = ['debian.sources'] + filenames = ['*.sources'] + version_added = '2.19' + url = 'https://manpages.debian.org/bookworm/apt/sources.list.5.en.html#THE_DEB_AND_DEB-SRC_TYPES:_GENERAL_FORMAT' + + tokens = { + 'root': [ + (r'^(Signed-By)(:)(\s*)', bygroups(Keyword, Punctuation, Whitespace), 'signed-by'), + (r'^([a-zA-Z\-0-9\.]*?)(:)(\s*)(.*?)$', + bygroups(Keyword, Punctuation, Whitespace, String)), + ], + 'signed-by': [ + (r' -----END PGP PUBLIC KEY BLOCK-----\n', Text, '#pop'), + (r'.+\n', Text), + ], + } + + +class SourcesListLexer(RegexLexer): + """ + Lexer that highlights debian sources.list files. + """ + + name = 'Debian Sourcelist' + aliases = ['debsources', 'sourceslist', 'sources.list'] + filenames = ['sources.list'] + version_added = '0.7' + mimetype = ['application/x-debian-sourceslist'] + url = 'https://wiki.debian.org/SourcesList' + + tokens = { + 'root': [ + (r'\s+', Whitespace), + (r'#.*?$', Comment), + (r'^(deb(?:-src)?)(\s+)', + bygroups(Keyword, Whitespace), 'distribution') + ], + 'distribution': [ + (r'#.*?$', Comment, '#pop'), + (r'\$\(ARCH\)', Name.Variable), + (r'[^\s$[]+', String), + (r'\[', String.Other, 'escaped-distribution'), + (r'\$', String), + (r'\s+', Whitespace, 'components') + ], + 'escaped-distribution': [ + (r'\]', String.Other, '#pop'), + (r'\$\(ARCH\)', Name.Variable), + (r'[^\]$]+', String.Other), + (r'\$', String.Other) + ], + 'components': [ + (r'#.*?$', Comment, '#pop:2'), + (r'$', Text, '#pop:2'), + (r'\s+', Whitespace), + (r'\S+', Keyword.Pseudo), + ] + } + + def analyse_text(text): + for line in text.splitlines(): + line = line.strip() + if line.startswith('deb ') or line.startswith('deb-src '): + return True + + +class DebianControlLexer(RegexLexer): + """ + Lexer for Debian ``control`` files and ``apt-cache show `` outputs. + """ + name = 'Debian Control file' + url = 'https://www.debian.org/doc/debian-policy/ch-controlfields.html' + aliases = ['debcontrol', 'control'] + filenames = ['control'] + version_added = '0.9' + + tokens = { + 'root': [ + (r'^(Description)', Keyword, 'description'), + (r'^(Maintainer|Uploaders|Changed-By)(:)(\s*)', + bygroups(Keyword, Punctuation, Whitespace), + 'maintainer'), + (r'^((?:Build-|Pre-)?Depends(?:-Indep|-Arch)?)(:)(\s*)', + bygroups(Keyword, Punctuation, Whitespace), 'package_list'), + (r'^(Recommends|Suggests|Enhances|Breaks|Replaces|Provides|Conflicts)(:)(\s*)', + bygroups(Keyword, Punctuation, Whitespace), 'package_list'), + (r'^((?:Python-)?Version)(:)(\s*)(\S+)$', + bygroups(Keyword, Punctuation, Whitespace, Number)), + (r'^((?:Installed-)?Size)(:)(\s*)(\S+)$', + bygroups(Keyword, Punctuation, Whitespace, Number)), + (r'^(MD5Sum|SHA1|SHA256)(:)(\s*)(\S+)$', + bygroups(Keyword, Punctuation, Whitespace, Number)), + (r'^([a-zA-Z\-0-9\.]*?)(:)(\s*)(.*?)$', + bygroups(Keyword, Punctuation, Whitespace, String)), + ], + 'maintainer': [ + (r'<[^>]+>$', Generic.Strong, '#pop'), + (r'<[^>]+>', Generic.Strong), + (r',\n?', Whitespace), + (r'[^,<]+$', Text, '#pop'), + (r'[^,<]+', Text), + ], + 'description': [ + (r'(.*)(Homepage)(: )(\S+)', + bygroups(Text, String, Name, Name.Class)), + (r':.*\n', Generic.Strong), + (r' .*\n', Text), + default('#pop'), + ], + 'package_list': [ + (r'(\$)(\{)(\w+)(\s*)(:)(\s*)(\w+)(\})', + bygroups(Operator, Punctuation, Name.Entity, Whitespace, + Punctuation, Whitespace, Text, Punctuation)), + (r'\(', Punctuation, 'package_list_vers'), + (r'\|', Operator), + (r'\n\s', Whitespace), + (r'\n', Whitespace, '#pop'), + (r'[,\s]', Text), + (r'[+.a-zA-Z0-9-]+', Name.Function), + (r'\[.*?\]', Name.Entity), + ], + 'package_list_vers': [ + (r'\)', Punctuation, '#pop'), + (r'([><=]+)(\s*)([^)]+)', bygroups(Operator, Whitespace, Number)), + ] + } diff --git a/pygments/lexers/int_fiction.py b/pygments/lexers/int_fiction.py new file mode 100644 index 0000000000000000000000000000000000000000..b142c544990306b5ccfd6d13a4f164e9e1dee3e7 --- /dev/null +++ b/pygments/lexers/int_fiction.py @@ -0,0 +1,1370 @@ +""" + pygments.lexers.int_fiction + ~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + Lexers for interactive fiction languages. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +import re + +from pygments.lexer import RegexLexer, include, bygroups, using, \ + this, default, words +from pygments.token import Text, Comment, Operator, Keyword, Name, String, \ + Number, Punctuation, Error, Generic + +__all__ = ['Inform6Lexer', 'Inform6TemplateLexer', 'Inform7Lexer', + 'Tads3Lexer'] + + +class Inform6Lexer(RegexLexer): + """ + For Inform 6 source code. + """ + + name = 'Inform 6' + url = 'http://inform-fiction.org/' + aliases = ['inform6', 'i6'] + filenames = ['*.inf'] + version_added = '2.0' + + flags = re.MULTILINE | re.DOTALL + + _name = r'[a-zA-Z_]\w*' + + # Inform 7 maps these four character classes to their ASCII + # equivalents. To support Inform 6 inclusions within Inform 7, + # Inform6Lexer maps them too. + _dash = '\\-\u2010-\u2014' + _dquote = '"\u201c\u201d' + _squote = "'\u2018\u2019" + _newline = '\\n\u0085\u2028\u2029' + + tokens = { + 'root': [ + (rf'\A(!%[^{_newline}]*[{_newline}])+', Comment.Preproc, + 'directive'), + default('directive') + ], + '_whitespace': [ + (r'\s+', Text), + (rf'![^{_newline}]*', Comment.Single) + ], + 'default': [ + include('_whitespace'), + (r'\[', Punctuation, 'many-values'), # Array initialization + (r':|(?=;)', Punctuation, '#pop'), + (r'<', Punctuation), # Second angle bracket in an action statement + default(('expression', '_expression')) + ], + + # Expressions + '_expression': [ + include('_whitespace'), + (r'(?=sp\b)', Text, '#pop'), + (rf'(?=[{_dquote}{_squote}$0-9#a-zA-Z_])', Text, + ('#pop', 'value')), + (rf'\+\+|[{_dash}]{{1,2}}(?!>)|~~?', Operator), + (rf'(?=[()\[{_dash},?@{{:;])', Text, '#pop') + ], + 'expression': [ + include('_whitespace'), + (r'\(', Punctuation, ('expression', '_expression')), + (r'\)', Punctuation, '#pop'), + (r'\[', Punctuation, ('#pop', 'statements', 'locals')), + (rf'>(?=(\s+|(![^{_newline}]*))*[>;])', Punctuation), + (rf'\+\+|[{_dash}]{{2}}(?!>)', Operator), + (r',', Punctuation, '_expression'), + (rf'&&?|\|\|?|[=~><]?=|[{_dash}]{{1,2}}>?|\.\.?[&#]?|::|[<>+*/%]', + Operator, '_expression'), + (r'(has|hasnt|in|notin|ofclass|or|provides)\b', Operator.Word, + '_expression'), + (r'sp\b', Name), + (r'\?~?', Name.Label, 'label?'), + (r'[@{]', Error), + default('#pop') + ], + '_assembly-expression': [ + (r'\(', Punctuation, ('#push', '_expression')), + (r'[\[\]]', Punctuation), + (rf'[{_dash}]>', Punctuation, '_expression'), + (r'sp\b', Keyword.Pseudo), + (r';', Punctuation, '#pop:3'), + include('expression') + ], + '_for-expression': [ + (r'\)', Punctuation, '#pop:2'), + (r':', Punctuation, '#pop'), + include('expression') + ], + '_keyword-expression': [ + (r'(from|near|to)\b', Keyword, '_expression'), + include('expression') + ], + '_list-expression': [ + (r',', Punctuation, '#pop'), + include('expression') + ], + '_object-expression': [ + (r'has\b', Keyword.Declaration, '#pop'), + include('_list-expression') + ], + + # Values + 'value': [ + include('_whitespace'), + # Strings + (rf'[{_squote}][^@][{_squote}]', String.Char, '#pop'), + (rf'([{_squote}])(@\{{[0-9a-fA-F]*\}})([{_squote}])', + bygroups(String.Char, String.Escape, String.Char), '#pop'), + (rf'([{_squote}])(@.{{2}})([{_squote}])', + bygroups(String.Char, String.Escape, String.Char), '#pop'), + (rf'[{_squote}]', String.Single, ('#pop', 'dictionary-word')), + (rf'[{_dquote}]', String.Double, ('#pop', 'string')), + # Numbers + (rf'\$[<>]?[+{_dash}][0-9]*\.?[0-9]*([eE][+{_dash}]?[0-9]+)?', + Number.Float, '#pop'), + (r'\$[0-9a-fA-F]+', Number.Hex, '#pop'), + (r'\$\$[01]+', Number.Bin, '#pop'), + (r'[0-9]+', Number.Integer, '#pop'), + # Values prefixed by hashes + (rf'(##|#a\$)({_name})', bygroups(Operator, Name), '#pop'), + (rf'(#g\$)({_name})', + bygroups(Operator, Name.Variable.Global), '#pop'), + (r'#[nw]\$', Operator, ('#pop', 'obsolete-dictionary-word')), + (rf'(#r\$)({_name})', bygroups(Operator, Name.Function), '#pop'), + (r'#', Name.Builtin, ('#pop', 'system-constant')), + # System functions + (words(( + 'child', 'children', 'elder', 'eldest', 'glk', 'indirect', 'metaclass', + 'parent', 'random', 'sibling', 'younger', 'youngest'), suffix=r'\b'), + Name.Builtin, '#pop'), + # Metaclasses + (r'(?i)(Class|Object|Routine|String)\b', Name.Builtin, '#pop'), + # Veneer routines + (words(( + 'Box__Routine', 'CA__Pr', 'CDefArt', 'CInDefArt', 'Cl__Ms', + 'Copy__Primitive', 'CP__Tab', 'DA__Pr', 'DB__Pr', 'DefArt', 'Dynam__String', + 'EnglishNumber', 'Glk__Wrap', 'IA__Pr', 'IB__Pr', 'InDefArt', 'Main__', + 'Meta__class', 'OB__Move', 'OB__Remove', 'OC__Cl', 'OP__Pr', 'Print__Addr', + 'Print__PName', 'PrintShortName', 'RA__Pr', 'RA__Sc', 'RL__Pr', 'R_Process', + 'RT__ChG', 'RT__ChGt', 'RT__ChLDB', 'RT__ChLDW', 'RT__ChPR', 'RT__ChPrintA', + 'RT__ChPrintC', 'RT__ChPrintO', 'RT__ChPrintS', 'RT__ChPS', 'RT__ChR', + 'RT__ChSTB', 'RT__ChSTW', 'RT__ChT', 'RT__Err', 'RT__TrPS', 'RV__Pr', + 'Symb__Tab', 'Unsigned__Compare', 'WV__Pr', 'Z__Region'), + prefix='(?i)', suffix=r'\b'), + Name.Builtin, '#pop'), + # Other built-in symbols + (words(( + 'call', 'copy', 'create', 'DEBUG', 'destroy', 'DICT_CHAR_SIZE', + 'DICT_ENTRY_BYTES', 'DICT_IS_UNICODE', 'DICT_WORD_SIZE', 'DOUBLE_HI_INFINITY', + 'DOUBLE_HI_NAN', 'DOUBLE_HI_NINFINITY', 'DOUBLE_LO_INFINITY', 'DOUBLE_LO_NAN', + 'DOUBLE_LO_NINFINITY', 'false', 'FLOAT_INFINITY', 'FLOAT_NAN', 'FLOAT_NINFINITY', + 'GOBJFIELD_CHAIN', 'GOBJFIELD_CHILD', 'GOBJFIELD_NAME', 'GOBJFIELD_PARENT', + 'GOBJFIELD_PROPTAB', 'GOBJFIELD_SIBLING', 'GOBJ_EXT_START', + 'GOBJ_TOTAL_LENGTH', 'Grammar__Version', 'INDIV_PROP_START', 'INFIX', + 'infix__watching', 'MODULE_MODE', 'name', 'nothing', 'NUM_ATTR_BYTES', 'print', + 'print_to_array', 'recreate', 'remaining', 'self', 'sender', 'STRICT_MODE', + 'sw__var', 'sys__glob0', 'sys__glob1', 'sys__glob2', 'sys_statusline_flag', + 'TARGET_GLULX', 'TARGET_ZCODE', 'temp__global2', 'temp__global3', + 'temp__global4', 'temp_global', 'true', 'USE_MODULES', 'WORDSIZE'), + prefix='(?i)', suffix=r'\b'), + Name.Builtin, '#pop'), + # Other values + (_name, Name, '#pop') + ], + 'value?': [ + include('value'), + default('#pop') + ], + # Strings + 'dictionary-word': [ + (rf'[~^]+|//[^{_squote}]*', String.Escape), + (rf'[^~^/\\@({{{_squote}]+', String.Single), + (r'[/({]', String.Single), + (r'@\{[0-9a-fA-F]*\}', String.Escape), + (r'@.{2}', String.Escape), + (rf'[{_squote}]', String.Single, '#pop') + ], + 'string': [ + (r'[~^]+', String.Escape), + (rf'[^~^\\@({{{_dquote}]+', String.Double), + (r'[({]', String.Double), + (r'\\', String.Escape), + (rf'@(\\\s*[{_newline}]\s*)*@((\\\s*[{_newline}]\s*)*[0-9])*', String.Escape), + (rf'@(\\\s*[{_newline}]\s*)*[({{]((\\\s*[{_newline}]\s*)*[0-9a-zA-Z_])*' + rf'(\\\s*[{_newline}]\s*)*[)}}]', + String.Escape), + (rf'@(\\\s*[{_newline}]\s*)*.(\\\s*[{_newline}]\s*)*.', + String.Escape), + (rf'[{_dquote}]', String.Double, '#pop') + ], + 'plain-string': [ + (rf'[^~^\\({{\[\]{_dquote}]+', String.Double), + (r'[~^({\[\]]', String.Double), + (r'\\', String.Escape), + (rf'[{_dquote}]', String.Double, '#pop') + ], + # Names + '_constant': [ + include('_whitespace'), + (_name, Name.Constant, '#pop'), + include('value') + ], + 'constant*': [ + include('_whitespace'), + (r',', Punctuation), + (r'=', Punctuation, 'value?'), + (_name, Name.Constant, 'value?'), + default('#pop') + ], + '_global': [ + include('_whitespace'), + (_name, Name.Variable.Global, '#pop'), + include('value') + ], + 'label?': [ + include('_whitespace'), + (_name, Name.Label, '#pop'), + default('#pop') + ], + 'variable?': [ + include('_whitespace'), + (_name, Name.Variable, '#pop'), + default('#pop') + ], + # Values after hashes + 'obsolete-dictionary-word': [ + (r'\S\w*', String.Other, '#pop') + ], + 'system-constant': [ + include('_whitespace'), + (_name, Name.Builtin, '#pop') + ], + + # Directives + 'directive': [ + include('_whitespace'), + (r'#', Punctuation), + (r';', Punctuation, '#pop'), + (r'\[', Punctuation, + ('default', 'statements', 'locals', 'routine-name?')), + (words(( + 'abbreviate', 'endif', 'dictionary', 'ifdef', 'iffalse', 'ifndef', 'ifnot', + 'iftrue', 'ifv3', 'ifv5', 'release', 'serial', 'switches', 'system_file', + 'version'), prefix='(?i)', suffix=r'\b'), + Keyword, 'default'), + (r'(?i)(array|global)\b', Keyword, + ('default', 'directive-keyword?', '_global')), + (r'(?i)attribute\b', Keyword, ('default', 'alias?', '_constant')), + (r'(?i)class\b', Keyword, + ('object-body', 'duplicates', 'class-name')), + (r'(?i)(constant|default)\b', Keyword, + ('default', 'constant*')), + (r'(?i)(end\b)(.*)', bygroups(Keyword, Text)), + (r'(?i)(extend|verb)\b', Keyword, 'grammar'), + (r'(?i)fake_action\b', Keyword, ('default', '_constant')), + (r'(?i)import\b', Keyword, 'manifest'), + (r'(?i)(include|link|origsource)\b', Keyword, + ('default', 'before-plain-string?')), + (r'(?i)(lowstring|undef)\b', Keyword, ('default', '_constant')), + (r'(?i)message\b', Keyword, ('default', 'diagnostic')), + (r'(?i)(nearby|object)\b', Keyword, + ('object-body', '_object-head')), + (r'(?i)property\b', Keyword, + ('default', 'alias?', '_constant', 'property-keyword*')), + (r'(?i)replace\b', Keyword, + ('default', 'routine-name?', 'routine-name?')), + (r'(?i)statusline\b', Keyword, ('default', 'directive-keyword?')), + (r'(?i)stub\b', Keyword, ('default', 'routine-name?')), + (r'(?i)trace\b', Keyword, + ('default', 'trace-keyword?', 'trace-keyword?')), + (r'(?i)zcharacter\b', Keyword, + ('default', 'directive-keyword?', 'directive-keyword?')), + (_name, Name.Class, ('object-body', '_object-head')) + ], + # [, Replace, Stub + 'routine-name?': [ + include('_whitespace'), + (_name, Name.Function, '#pop'), + default('#pop') + ], + 'locals': [ + include('_whitespace'), + (r';', Punctuation, '#pop'), + (r'\*', Punctuation), + (r'"', String.Double, 'plain-string'), + (_name, Name.Variable) + ], + # Array + 'many-values': [ + include('_whitespace'), + (r';', Punctuation), + (r'\]', Punctuation, '#pop'), + (r':', Error), + default(('expression', '_expression')) + ], + # Attribute, Property + 'alias?': [ + include('_whitespace'), + (r'alias\b', Keyword, ('#pop', '_constant')), + default('#pop') + ], + # Class, Object, Nearby + 'class-name': [ + include('_whitespace'), + (r'(?=[,;]|(class|has|private|with)\b)', Text, '#pop'), + (_name, Name.Class, '#pop') + ], + 'duplicates': [ + include('_whitespace'), + (r'\(', Punctuation, ('#pop', 'expression', '_expression')), + default('#pop') + ], + '_object-head': [ + (rf'[{_dash}]>', Punctuation), + (r'(class|has|private|with)\b', Keyword.Declaration, '#pop'), + include('_global') + ], + 'object-body': [ + include('_whitespace'), + (r';', Punctuation, '#pop:2'), + (r',', Punctuation), + (r'class\b', Keyword.Declaration, 'class-segment'), + (r'(has|private|with)\b', Keyword.Declaration), + (r':', Error), + default(('_object-expression', '_expression')) + ], + 'class-segment': [ + include('_whitespace'), + (r'(?=[,;]|(class|has|private|with)\b)', Text, '#pop'), + (_name, Name.Class), + default('value') + ], + # Extend, Verb + 'grammar': [ + include('_whitespace'), + (r'=', Punctuation, ('#pop', 'default')), + (r'\*', Punctuation, ('#pop', 'grammar-line')), + default('_directive-keyword') + ], + 'grammar-line': [ + include('_whitespace'), + (r';', Punctuation, '#pop'), + (r'[/*]', Punctuation), + (rf'[{_dash}]>', Punctuation, 'value'), + (r'(noun|scope)\b', Keyword, '=routine'), + default('_directive-keyword') + ], + '=routine': [ + include('_whitespace'), + (r'=', Punctuation, 'routine-name?'), + default('#pop') + ], + # Import + 'manifest': [ + include('_whitespace'), + (r';', Punctuation, '#pop'), + (r',', Punctuation), + (r'(?i)global\b', Keyword, '_global'), + default('_global') + ], + # Include, Link, Message + 'diagnostic': [ + include('_whitespace'), + (rf'[{_dquote}]', String.Double, ('#pop', 'message-string')), + default(('#pop', 'before-plain-string?', 'directive-keyword?')) + ], + 'before-plain-string?': [ + include('_whitespace'), + (rf'[{_dquote}]', String.Double, ('#pop', 'plain-string')), + default('#pop') + ], + 'message-string': [ + (r'[~^]+', String.Escape), + include('plain-string') + ], + + # Keywords used in directives + '_directive-keyword!': [ + include('_whitespace'), + (words(( + 'additive', 'alias', 'buffer', 'class', 'creature', 'data', 'error', 'fatalerror', + 'first', 'has', 'held', 'individual', 'initial', 'initstr', 'last', 'long', 'meta', + 'multi', 'multiexcept', 'multiheld', 'multiinside', 'noun', 'number', 'only', + 'private', 'replace', 'reverse', 'scope', 'score', 'special', 'string', 'table', + 'terminating', 'time', 'topic', 'warning', 'with'), suffix=r'\b'), + Keyword, '#pop'), + (r'static\b', Keyword), + (rf'[{_dash}]{{1,2}}>|[+=]', Punctuation, '#pop') + ], + '_directive-keyword': [ + include('_directive-keyword!'), + include('value') + ], + 'directive-keyword?': [ + include('_directive-keyword!'), + default('#pop') + ], + 'property-keyword*': [ + include('_whitespace'), + (words(('additive', 'individual', 'long'), + suffix=rf'\b(?=(\s*|(![^{_newline}]*[{_newline}]))*[_a-zA-Z])'), + Keyword), + default('#pop') + ], + 'trace-keyword?': [ + include('_whitespace'), + (words(( + 'assembly', 'dictionary', 'expressions', 'lines', 'linker', + 'objects', 'off', 'on', 'symbols', 'tokens', 'verbs'), suffix=r'\b'), + Keyword, '#pop'), + default('#pop') + ], + + # Statements + 'statements': [ + include('_whitespace'), + (r'\]', Punctuation, '#pop'), + (r'[;{}]', Punctuation), + (words(( + 'box', 'break', 'continue', 'default', 'give', 'inversion', + 'new_line', 'quit', 'read', 'remove', 'return', 'rfalse', 'rtrue', + 'spaces', 'string', 'until'), suffix=r'\b'), + Keyword, 'default'), + (r'(do|else)\b', Keyword), + (r'(font|style)\b', Keyword, + ('default', 'miscellaneous-keyword?')), + (r'for\b', Keyword, ('for', '(?')), + (r'(if|switch|while)', Keyword, + ('expression', '_expression', '(?')), + (r'(jump|save|restore)\b', Keyword, ('default', 'label?')), + (r'objectloop\b', Keyword, + ('_keyword-expression', 'variable?', '(?')), + (rf'print(_ret)?\b|(?=[{_dquote}])', Keyword, 'print-list'), + (r'\.', Name.Label, 'label?'), + (r'@', Keyword, 'opcode'), + (r'#(?![agrnw]\$|#)', Punctuation, 'directive'), + (r'<', Punctuation, 'default'), + (r'move\b', Keyword, + ('default', '_keyword-expression', '_expression')), + default(('default', '_keyword-expression', '_expression')) + ], + 'miscellaneous-keyword?': [ + include('_whitespace'), + (r'(bold|fixed|from|near|off|on|reverse|roman|to|underline)\b', + Keyword, '#pop'), + (r'(a|A|an|address|char|name|number|object|property|string|the|' + rf'The)\b(?=(\s+|(![^{_newline}]*))*\))', Keyword.Pseudo, + '#pop'), + (rf'{_name}(?=(\s+|(![^{_newline}]*))*\))', Name.Function, + '#pop'), + default('#pop') + ], + '(?': [ + include('_whitespace'), + (r'\(', Punctuation, '#pop'), + default('#pop') + ], + 'for': [ + include('_whitespace'), + (r';', Punctuation, ('_for-expression', '_expression')), + default(('_for-expression', '_expression')) + ], + 'print-list': [ + include('_whitespace'), + (r';', Punctuation, '#pop'), + (r':', Error), + default(('_list-expression', '_expression', '_list-expression', 'form')) + ], + 'form': [ + include('_whitespace'), + (r'\(', Punctuation, ('#pop', 'miscellaneous-keyword?')), + default('#pop') + ], + + # Assembly + 'opcode': [ + include('_whitespace'), + (rf'[{_dquote}]', String.Double, ('operands', 'plain-string')), + (rf'[{_dash}]{{1,2}}>', Punctuation, 'operands'), + (_name, Keyword, 'operands') + ], + 'operands': [ + (r':', Error), + default(('_assembly-expression', '_expression')) + ] + } + + def get_tokens_unprocessed(self, text): + # 'in' is either a keyword or an operator. + # If the token two tokens after 'in' is ')', 'in' is a keyword: + # objectloop(a in b) + # Otherwise, it is an operator: + # objectloop(a in b && true) + objectloop_queue = [] + objectloop_token_count = -1 + previous_token = None + for index, token, value in RegexLexer.get_tokens_unprocessed(self, + text): + if previous_token is Name.Variable and value == 'in': + objectloop_queue = [[index, token, value]] + objectloop_token_count = 2 + elif objectloop_token_count > 0: + if token not in Comment and token not in Text: + objectloop_token_count -= 1 + objectloop_queue.append((index, token, value)) + else: + if objectloop_token_count == 0: + if objectloop_queue[-1][2] == ')': + objectloop_queue[0][1] = Keyword + while objectloop_queue: + yield objectloop_queue.pop(0) + objectloop_token_count = -1 + yield index, token, value + if token not in Comment and token not in Text: + previous_token = token + while objectloop_queue: + yield objectloop_queue.pop(0) + + def analyse_text(text): + """We try to find a keyword which seem relatively common, unfortunately + there is a decent overlap with Smalltalk keywords otherwise here..""" + result = 0 + if re.search('\borigsource\b', text, re.IGNORECASE): + result += 0.05 + + return result + + +class Inform7Lexer(RegexLexer): + """ + For Inform 7 source code. + """ + + name = 'Inform 7' + url = 'http://inform7.com/' + aliases = ['inform7', 'i7'] + filenames = ['*.ni', '*.i7x'] + version_added = '2.0' + + flags = re.MULTILINE | re.DOTALL + + _dash = Inform6Lexer._dash + _dquote = Inform6Lexer._dquote + _newline = Inform6Lexer._newline + _start = rf'\A|(?<=[{_newline}])' + + # There are three variants of Inform 7, differing in how to + # interpret at signs and braces in I6T. In top-level inclusions, at + # signs in the first column are inweb syntax. In phrase definitions + # and use options, tokens in braces are treated as I7. Use options + # also interpret "{N}". + tokens = {} + token_variants = ['+i6t-not-inline', '+i6t-inline', '+i6t-use-option'] + + for level in token_variants: + tokens[level] = { + '+i6-root': list(Inform6Lexer.tokens['root']), + '+i6t-root': [ # For Inform6TemplateLexer + (rf'[^{Inform6Lexer._newline}]*', Comment.Preproc, + ('directive', '+p')) + ], + 'root': [ + (r'(\|?\s)+', Text), + (r'\[', Comment.Multiline, '+comment'), + (rf'[{_dquote}]', Generic.Heading, + ('+main', '+titling', '+titling-string')), + default(('+main', '+heading?')) + ], + '+titling-string': [ + (rf'[^{_dquote}]+', Generic.Heading), + (rf'[{_dquote}]', Generic.Heading, '#pop') + ], + '+titling': [ + (r'\[', Comment.Multiline, '+comment'), + (rf'[^{_dquote}.;:|{_newline}]+', Generic.Heading), + (rf'[{_dquote}]', Generic.Heading, '+titling-string'), + (rf'[{_newline}]{{2}}|(?<=[\s{_dquote}])\|[\s{_dquote}]', + Text, ('#pop', '+heading?')), + (rf'[.;:]|(?<=[\s{_dquote}])\|', Text, '#pop'), + (rf'[|{_newline}]', Generic.Heading) + ], + '+main': [ + (rf'(?i)[^{_dquote}:a\[(|{_newline}]+', Text), + (rf'[{_dquote}]', String.Double, '+text'), + (r':', Text, '+phrase-definition'), + (r'(?i)\bas\b', Text, '+use-option'), + (r'\[', Comment.Multiline, '+comment'), + (rf'(\([{_dash}])(.*?)([{_dash}]\))', + bygroups(Punctuation, + using(this, state=('+i6-root', 'directive'), + i6t='+i6t-not-inline'), Punctuation)), + (rf'({_start}|(?<=[\s;:.{_dquote}]))\|\s|[{_newline}]{{2,}}', Text, '+heading?'), + (rf'(?i)[a(|{_newline}]', Text) + ], + '+phrase-definition': [ + (r'\s+', Text), + (r'\[', Comment.Multiline, '+comment'), + (rf'(\([{_dash}])(.*?)([{_dash}]\))', + bygroups(Punctuation, + using(this, state=('+i6-root', 'directive', + 'default', 'statements'), + i6t='+i6t-inline'), Punctuation), '#pop'), + default('#pop') + ], + '+use-option': [ + (r'\s+', Text), + (r'\[', Comment.Multiline, '+comment'), + (rf'(\([{_dash}])(.*?)([{_dash}]\))', + bygroups(Punctuation, + using(this, state=('+i6-root', 'directive'), + i6t='+i6t-use-option'), Punctuation), '#pop'), + default('#pop') + ], + '+comment': [ + (r'[^\[\]]+', Comment.Multiline), + (r'\[', Comment.Multiline, '#push'), + (r'\]', Comment.Multiline, '#pop') + ], + '+text': [ + (rf'[^\[{_dquote}]+', String.Double), + (r'\[.*?\]', String.Interpol), + (rf'[{_dquote}]', String.Double, '#pop') + ], + '+heading?': [ + (r'(\|?\s)+', Text), + (r'\[', Comment.Multiline, '+comment'), + (rf'[{_dash}]{{4}}\s+', Text, '+documentation-heading'), + (rf'[{_dash}]{{1,3}}', Text), + (rf'(?i)(volume|book|part|chapter|section)\b[^{_newline}]*', + Generic.Heading, '#pop'), + default('#pop') + ], + '+documentation-heading': [ + (r'\s+', Text), + (r'\[', Comment.Multiline, '+comment'), + (r'(?i)documentation\s+', Text, '+documentation-heading2'), + default('#pop') + ], + '+documentation-heading2': [ + (r'\s+', Text), + (r'\[', Comment.Multiline, '+comment'), + (rf'[{_dash}]{{4}}\s', Text, '+documentation'), + default('#pop:2') + ], + '+documentation': [ + (rf'(?i)({_start})\s*(chapter|example)\s*:[^{_newline}]*', Generic.Heading), + (rf'(?i)({_start})\s*section\s*:[^{_newline}]*', + Generic.Subheading), + (rf'(({_start})\t.*?[{_newline}])+', + using(this, state='+main')), + (rf'[^{_newline}\[]+|[{_newline}\[]', Text), + (r'\[', Comment.Multiline, '+comment'), + ], + '+i6t-not-inline': [ + (rf'({_start})@c( .*?)?([{_newline}]|\Z)', + Comment.Preproc), + (rf'({_start})@([{_dash}]+|Purpose:)[^{_newline}]*', + Comment.Preproc), + (rf'({_start})@p( .*?)?([{_newline}]|\Z)', + Generic.Heading, '+p') + ], + '+i6t-use-option': [ + include('+i6t-not-inline'), + (r'(\{)(N)(\})', bygroups(Punctuation, Text, Punctuation)) + ], + '+i6t-inline': [ + (r'(\{)(\S[^}]*)?(\})', + bygroups(Punctuation, using(this, state='+main'), + Punctuation)) + ], + '+i6t': [ + (rf'(\{{[{_dash}])(![^}}]*)(\}}?)', + bygroups(Punctuation, Comment.Single, Punctuation)), + (rf'(\{{[{_dash}])(lines)(:)([^}}]*)(\}}?)', + bygroups(Punctuation, Keyword, Punctuation, Text, + Punctuation), '+lines'), + (rf'(\{{[{_dash}])([^:}}]*)(:?)([^}}]*)(\}}?)', + bygroups(Punctuation, Keyword, Punctuation, Text, + Punctuation)), + (r'(\(\+)(.*?)(\+\)|\Z)', + bygroups(Punctuation, using(this, state='+main'), + Punctuation)) + ], + '+p': [ + (r'[^@]+', Comment.Preproc), + (rf'({_start})@c( .*?)?([{_newline}]|\Z)', + Comment.Preproc, '#pop'), + (rf'({_start})@([{_dash}]|Purpose:)', Comment.Preproc), + (rf'({_start})@p( .*?)?([{_newline}]|\Z)', + Generic.Heading), + (r'@', Comment.Preproc) + ], + '+lines': [ + (rf'({_start})@c( .*?)?([{_newline}]|\Z)', + Comment.Preproc), + (rf'({_start})@([{_dash}]|Purpose:)[^{_newline}]*', + Comment.Preproc), + (rf'({_start})@p( .*?)?([{_newline}]|\Z)', + Generic.Heading, '+p'), + (rf'({_start})@\w*[ {_newline}]', Keyword), + (rf'![^{_newline}]*', Comment.Single), + (rf'(\{{)([{_dash}]endlines)(\}})', + bygroups(Punctuation, Keyword, Punctuation), '#pop'), + (rf'[^@!{{]+?([{_newline}]|\Z)|.', Text) + ] + } + # Inform 7 can include snippets of Inform 6 template language, + # so all of Inform6Lexer's states are copied here, with + # modifications to account for template syntax. Inform7Lexer's + # own states begin with '+' to avoid name conflicts. Some of + # Inform6Lexer's states begin with '_': these are not modified. + # They deal with template syntax either by including modified + # states, or by matching r'' then pushing to modified states. + for token in Inform6Lexer.tokens: + if token == 'root': + continue + tokens[level][token] = list(Inform6Lexer.tokens[token]) + if not token.startswith('_'): + tokens[level][token][:0] = [include('+i6t'), include(level)] + + def __init__(self, **options): + level = options.get('i6t', '+i6t-not-inline') + if level not in self._all_tokens: + self._tokens = self.__class__.process_tokendef(level) + else: + self._tokens = self._all_tokens[level] + RegexLexer.__init__(self, **options) + + +class Inform6TemplateLexer(Inform7Lexer): + """ + For Inform 6 template code. + """ + + name = 'Inform 6 template' + aliases = ['i6t'] + filenames = ['*.i6t'] + version_added = '2.0' + + def get_tokens_unprocessed(self, text, stack=('+i6t-root',)): + return Inform7Lexer.get_tokens_unprocessed(self, text, stack) + + +class Tads3Lexer(RegexLexer): + """ + For TADS 3 source code. + """ + + name = 'TADS 3' + aliases = ['tads3'] + filenames = ['*.t'] + url = 'https://www.tads.org' + version_added = '' + + flags = re.DOTALL | re.MULTILINE + + _comment_single = r'(?://(?:[^\\\n]|\\+[\w\W])*$)' + _comment_multiline = r'(?:/\*(?:[^*]|\*(?!/))*\*/)' + _escape = (r'(?:\\(?:[\n\\<>"\'^v bnrt]|u[\da-fA-F]{,4}|x[\da-fA-F]{,2}|' + r'[0-3]?[0-7]{1,2}))') + _name = r'(?:[_a-zA-Z]\w*)' + _no_quote = r'(?=\s|\\?>)' + _operator = (r'(?:&&|\|\||\+\+|--|\?\?|::|[.,@\[\]~]|' + r'(?:[=+\-*/%!&|^]|<>?>?)=?)') + _ws = rf'(?:\\|\s|{_comment_single}|{_comment_multiline})' + _ws_pp = rf'(?:\\\n|[^\S\n]|{_comment_single}|{_comment_multiline})' + + def _make_string_state(triple, double, verbatim=None, _escape=_escape): + if verbatim: + verbatim = ''.join([f'(?:{re.escape(c.lower())}|{re.escape(c.upper())})' + for c in verbatim]) + char = r'"' if double else r"'" + token = String.Double if double else String.Single + escaped_quotes = rf'+|{char}(?!{char}{{2}})' if triple else r'' + prefix = '{}{}'.format('t' if triple else '', 'd' if double else 's') + tag_state_name = f'{prefix}qt' + state = [] + if triple: + state += [ + (rf'{char}{{3,}}', token, '#pop'), + (rf'\\{char}+', String.Escape), + (char, token) + ] + else: + state.append((char, token, '#pop')) + state += [ + include('s/verbatim'), + (rf'[^\\<&{{}}{char}]+', token) + ] + if verbatim: + # This regex can't use `(?i)` because escape sequences are + # case-sensitive. `<\XMP>` works; `<\xmp>` doesn't. + state.append((rf'\\?<(/|\\\\|(?!{_escape})\\){verbatim}(?=[\s=>])', + Name.Tag, ('#pop', f'{prefix}qs', tag_state_name))) + else: + state += [ + (rf'\\?<\\{char}]|<(?!<)|\\{char}{escaped_quotes}|{_escape}|\\.)*>?', Comment.Multiline), + (r'(?i)\\?]|\\>)', Name.Tag, + ('#pop', f'{prefix}qs/listing', tag_state_name)), + (r'(?i)\\?]|\\>)', Name.Tag, + ('#pop', f'{prefix}qs/xmp', tag_state_name)), + (rf'\\?<([^\s=><\\{char}]|<(?!<)|\\{char}{escaped_quotes}|{_escape}|\\.)*', Name.Tag, + tag_state_name), + include('s/entity') + ] + state += [ + include('s/escape'), + (rf'\{{([^}}<\\{char}]|<(?!<)|\\{char}{escaped_quotes}|{_escape}|\\.)*\}}', String.Interpol), + (r'[\\&{}<]', token) + ] + return state + + def _make_tag_state(triple, double, _escape=_escape): + char = r'"' if double else r"'" + quantifier = r'{3,}' if triple else r'' + state_name = '{}{}qt'.format('t' if triple else '', 'd' if double else 's') + token = String.Double if double else String.Single + escaped_quotes = rf'+|{char}(?!{char}{{2}})' if triple else r'' + return [ + (rf'{char}{quantifier}', token, '#pop:2'), + (r'(\s|\\\n)+', Text), + (r'(=)(\\?")', bygroups(Punctuation, String.Double), + f'dqs/{state_name}'), + (r"(=)(\\?')", bygroups(Punctuation, String.Single), + f'sqs/{state_name}'), + (r'=', Punctuation, f'uqs/{state_name}'), + (r'\\?>', Name.Tag, '#pop'), + (rf'\{{([^}}<\\{char}]|<(?!<)|\\{char}{escaped_quotes}|{_escape}|\\.)*\}}', String.Interpol), + (rf'([^\s=><\\{char}]|<(?!<)|\\{char}{escaped_quotes}|{_escape}|\\.)+', Name.Attribute), + include('s/escape'), + include('s/verbatim'), + include('s/entity'), + (r'[\\{}&]', Name.Attribute) + ] + + def _make_attribute_value_state(terminator, host_triple, host_double, + _escape=_escape): + token = (String.Double if terminator == r'"' else + String.Single if terminator == r"'" else String.Other) + host_char = r'"' if host_double else r"'" + host_quantifier = r'{3,}' if host_triple else r'' + host_token = String.Double if host_double else String.Single + escaped_quotes = (rf'+|{host_char}(?!{host_char}{{2}})' + if host_triple else r'') + return [ + (rf'{host_char}{host_quantifier}', host_token, '#pop:3'), + (r'{}{}'.format(r'' if token is String.Other else r'\\?', terminator), + token, '#pop'), + include('s/verbatim'), + include('s/entity'), + (rf'\{{([^}}<\\{host_char}]|<(?!<)|\\{host_char}{escaped_quotes}|{_escape}|\\.)*\}}', String.Interpol), + (r'([^\s"\'<%s{}\\&])+' % (r'>' if token is String.Other else r''), + token), + include('s/escape'), + (r'["\'\s&{<}\\]', token) + ] + + tokens = { + 'root': [ + ('\ufeff', Text), + (r'\{', Punctuation, 'object-body'), + (r';+', Punctuation), + (r'(?=(argcount|break|case|catch|continue|default|definingobj|' + r'delegated|do|else|for|foreach|finally|goto|if|inherited|' + r'invokee|local|nil|new|operator|replaced|return|self|switch|' + r'targetobj|targetprop|throw|true|try|while)\b)', Text, 'block'), + (rf'({_name})({_ws}*)(\()', + bygroups(Name.Function, using(this, state='whitespace'), + Punctuation), + ('block?/root', 'more/parameters', 'main/parameters')), + include('whitespace'), + (r'\++', Punctuation), + (r'[^\s!"%-(*->@-_a-z{-~]+', Error), # Averts an infinite loop + (r'(?!\Z)', Text, 'main/root') + ], + 'main/root': [ + include('main/basic'), + default(('#pop', 'object-body/no-braces', 'classes', 'class')) + ], + 'object-body/no-braces': [ + (r';', Punctuation, '#pop'), + (r'\{', Punctuation, ('#pop', 'object-body')), + include('object-body') + ], + 'object-body': [ + (r';', Punctuation), + (r'\{', Punctuation, '#push'), + (r'\}', Punctuation, '#pop'), + (r':', Punctuation, ('classes', 'class')), + (rf'({_name}?)({_ws}*)(\()', + bygroups(Name.Function, using(this, state='whitespace'), + Punctuation), + ('block?', 'more/parameters', 'main/parameters')), + (rf'({_name})({_ws}*)(\{{)', + bygroups(Name.Function, using(this, state='whitespace'), + Punctuation), 'block'), + (rf'({_name})({_ws}*)(:)', + bygroups(Name.Variable, using(this, state='whitespace'), + Punctuation), + ('object-body/no-braces', 'classes', 'class')), + include('whitespace'), + (rf'->|{_operator}', Punctuation, 'main'), + default('main/object-body') + ], + 'main/object-body': [ + include('main/basic'), + (rf'({_name})({_ws}*)(=?)', + bygroups(Name.Variable, using(this, state='whitespace'), + Punctuation), ('#pop', 'more', 'main')), + default('#pop:2') + ], + 'block?/root': [ + (r'\{', Punctuation, ('#pop', 'block')), + include('whitespace'), + (r'(?=[\[\'"<(:])', Text, # It might be a VerbRule macro. + ('#pop', 'object-body/no-braces', 'grammar', 'grammar-rules')), + # It might be a macro like DefineAction. + default(('#pop', 'object-body/no-braces')) + ], + 'block?': [ + (r'\{', Punctuation, ('#pop', 'block')), + include('whitespace'), + default('#pop') + ], + 'block/basic': [ + (r'[;:]+', Punctuation), + (r'\{', Punctuation, '#push'), + (r'\}', Punctuation, '#pop'), + (r'default\b', Keyword.Reserved), + (rf'({_name})({_ws}*)(:)', + bygroups(Name.Label, using(this, state='whitespace'), + Punctuation)), + include('whitespace') + ], + 'block': [ + include('block/basic'), + (r'(?!\Z)', Text, ('more', 'main')) + ], + 'block/embed': [ + (r'>>', String.Interpol, '#pop'), + include('block/basic'), + (r'(?!\Z)', Text, ('more/embed', 'main')) + ], + 'main/basic': [ + include('whitespace'), + (r'\(', Punctuation, ('#pop', 'more', 'main')), + (r'\[', Punctuation, ('#pop', 'more/list', 'main')), + (r'\{', Punctuation, ('#pop', 'more/inner', 'main/inner', + 'more/parameters', 'main/parameters')), + (r'\*|\.{3}', Punctuation, '#pop'), + (r'(?i)0x[\da-f]+', Number.Hex, '#pop'), + (r'(\d+\.(?!\.)\d*|\.\d+)([eE][-+]?\d+)?|\d+[eE][-+]?\d+', + Number.Float, '#pop'), + (r'0[0-7]+', Number.Oct, '#pop'), + (r'\d+', Number.Integer, '#pop'), + (r'"""', String.Double, ('#pop', 'tdqs')), + (r"'''", String.Single, ('#pop', 'tsqs')), + (r'"', String.Double, ('#pop', 'dqs')), + (r"'", String.Single, ('#pop', 'sqs')), + (r'R"""', String.Regex, ('#pop', 'tdqr')), + (r"R'''", String.Regex, ('#pop', 'tsqr')), + (r'R"', String.Regex, ('#pop', 'dqr')), + (r"R'", String.Regex, ('#pop', 'sqr')), + # Two-token keywords + (rf'(extern)({_ws}+)(object\b)', + bygroups(Keyword.Reserved, using(this, state='whitespace'), + Keyword.Reserved)), + (rf'(function|method)({_ws}*)(\()', + bygroups(Keyword.Reserved, using(this, state='whitespace'), + Punctuation), + ('#pop', 'block?', 'more/parameters', 'main/parameters')), + (rf'(modify)({_ws}+)(grammar\b)', + bygroups(Keyword.Reserved, using(this, state='whitespace'), + Keyword.Reserved), + ('#pop', 'object-body/no-braces', ':', 'grammar')), + (rf'(new)({_ws}+(?=(?:function|method)\b))', + bygroups(Keyword.Reserved, using(this, state='whitespace'))), + (rf'(object)({_ws}+)(template\b)', + bygroups(Keyword.Reserved, using(this, state='whitespace'), + Keyword.Reserved), ('#pop', 'template')), + (rf'(string)({_ws}+)(template\b)', + bygroups(Keyword, using(this, state='whitespace'), + Keyword.Reserved), ('#pop', 'function-name')), + # Keywords + (r'(argcount|definingobj|invokee|replaced|targetobj|targetprop)\b', + Name.Builtin, '#pop'), + (r'(break|continue|goto)\b', Keyword.Reserved, ('#pop', 'label')), + (r'(case|extern|if|intrinsic|return|static|while)\b', + Keyword.Reserved), + (r'catch\b', Keyword.Reserved, ('#pop', 'catch')), + (r'class\b', Keyword.Reserved, + ('#pop', 'object-body/no-braces', 'class')), + (r'(default|do|else|finally|try)\b', Keyword.Reserved, '#pop'), + (r'(dictionary|property)\b', Keyword.Reserved, + ('#pop', 'constants')), + (r'enum\b', Keyword.Reserved, ('#pop', 'enum')), + (r'export\b', Keyword.Reserved, ('#pop', 'main')), + (r'(for|foreach)\b', Keyword.Reserved, + ('#pop', 'more/inner', 'main/inner')), + (r'(function|method)\b', Keyword.Reserved, + ('#pop', 'block?', 'function-name')), + (r'grammar\b', Keyword.Reserved, + ('#pop', 'object-body/no-braces', 'grammar')), + (r'inherited\b', Keyword.Reserved, ('#pop', 'inherited')), + (r'local\b', Keyword.Reserved, + ('#pop', 'more/local', 'main/local')), + (r'(modify|replace|switch|throw|transient)\b', Keyword.Reserved, + '#pop'), + (r'new\b', Keyword.Reserved, ('#pop', 'class')), + (r'(nil|true)\b', Keyword.Constant, '#pop'), + (r'object\b', Keyword.Reserved, ('#pop', 'object-body/no-braces')), + (r'operator\b', Keyword.Reserved, ('#pop', 'operator')), + (r'propertyset\b', Keyword.Reserved, + ('#pop', 'propertyset', 'main')), + (r'self\b', Name.Builtin.Pseudo, '#pop'), + (r'template\b', Keyword.Reserved, ('#pop', 'template')), + # Operators + (rf'(__objref|defined)({_ws}*)(\()', + bygroups(Operator.Word, using(this, state='whitespace'), + Operator), ('#pop', 'more/__objref', 'main')), + (r'delegated\b', Operator.Word), + # Compiler-defined macros and built-in properties + (r'(__DATE__|__DEBUG|__LINE__|__FILE__|' + r'__TADS_MACRO_FORMAT_VERSION|__TADS_SYS_\w*|__TADS_SYSTEM_NAME|' + r'__TADS_VERSION_MAJOR|__TADS_VERSION_MINOR|__TADS3|__TIME__|' + r'construct|finalize|grammarInfo|grammarTag|lexicalParent|' + r'miscVocab|sourceTextGroup|sourceTextGroupName|' + r'sourceTextGroupOrder|sourceTextOrder)\b', Name.Builtin, '#pop') + ], + 'main': [ + include('main/basic'), + (_name, Name, '#pop'), + default('#pop') + ], + 'more/basic': [ + (r'\(', Punctuation, ('more/list', 'main')), + (r'\[', Punctuation, ('more', 'main')), + (r'\.{3}', Punctuation), + (r'->|\.\.', Punctuation, 'main'), + (r'(?=;)|[:)\]]', Punctuation, '#pop'), + include('whitespace'), + (_operator, Operator, 'main'), + (r'\?', Operator, ('main', 'more/conditional', 'main')), + (rf'(is|not)({_ws}+)(in\b)', + bygroups(Operator.Word, using(this, state='whitespace'), + Operator.Word)), + (r'[^\s!"%-_a-z{-~]+', Error) # Averts an infinite loop + ], + 'more': [ + include('more/basic'), + default('#pop') + ], + # Then expression (conditional operator) + 'more/conditional': [ + (r':(?!:)', Operator, '#pop'), + include('more') + ], + # Embedded expressions + 'more/embed': [ + (r'>>', String.Interpol, '#pop:2'), + include('more') + ], + # For/foreach loop initializer or short-form anonymous function + 'main/inner': [ + (r'\(', Punctuation, ('#pop', 'more/inner', 'main/inner')), + (r'local\b', Keyword.Reserved, ('#pop', 'main/local')), + include('main') + ], + 'more/inner': [ + (r'\}', Punctuation, '#pop'), + (r',', Punctuation, 'main/inner'), + (r'(in|step)\b', Keyword, 'main/inner'), + include('more') + ], + # Local + 'main/local': [ + (_name, Name.Variable, '#pop'), + include('whitespace') + ], + 'more/local': [ + (r',', Punctuation, 'main/local'), + include('more') + ], + # List + 'more/list': [ + (r'[,:]', Punctuation, 'main'), + include('more') + ], + # Parameter list + 'main/parameters': [ + (rf'({_name})({_ws}*)(?=:)', + bygroups(Name.Variable, using(this, state='whitespace')), '#pop'), + (rf'({_name})({_ws}+)({_name})', + bygroups(Name.Class, using(this, state='whitespace'), + Name.Variable), '#pop'), + (r'\[+', Punctuation), + include('main/basic'), + (_name, Name.Variable, '#pop'), + default('#pop') + ], + 'more/parameters': [ + (rf'(:)({_ws}*(?=[?=,:)]))', + bygroups(Punctuation, using(this, state='whitespace'))), + (r'[?\]]+', Punctuation), + (r'[:)]', Punctuation, ('#pop', 'multimethod?')), + (r',', Punctuation, 'main/parameters'), + (r'=', Punctuation, ('more/parameter', 'main')), + include('more') + ], + 'more/parameter': [ + (r'(?=[,)])', Text, '#pop'), + include('more') + ], + 'multimethod?': [ + (r'multimethod\b', Keyword, '#pop'), + include('whitespace'), + default('#pop') + ], + + # Statements and expressions + 'more/__objref': [ + (r',', Punctuation, 'mode'), + (r'\)', Operator, '#pop'), + include('more') + ], + 'mode': [ + (r'(error|warn)\b', Keyword, '#pop'), + include('whitespace') + ], + 'catch': [ + (r'\(+', Punctuation), + (_name, Name.Exception, ('#pop', 'variables')), + include('whitespace') + ], + 'enum': [ + include('whitespace'), + (r'token\b', Keyword, ('#pop', 'constants')), + default(('#pop', 'constants')) + ], + 'grammar': [ + (r'\)+', Punctuation), + (r'\(', Punctuation, 'grammar-tag'), + (r':', Punctuation, 'grammar-rules'), + (_name, Name.Class), + include('whitespace') + ], + 'grammar-tag': [ + include('whitespace'), + (r'"""([^\\"<]|""?(?!")|\\"+|\\.|<(?!<))+("{3,}|<<)|' + r'R"""([^\\"]|""?(?!")|\\"+|\\.)+"{3,}|' + r"'''([^\\'<]|''?(?!')|\\'+|\\.|<(?!<))+('{3,}|<<)|" + r"R'''([^\\']|''?(?!')|\\'+|\\.)+'{3,}|" + r'"([^\\"<]|\\.|<(?!<))+("|<<)|R"([^\\"]|\\.)+"|' + r"'([^\\'<]|\\.|<(?!<))+('|<<)|R'([^\\']|\\.)+'|" + r"([^)\s\\/]|/(?![/*]))+|\)", String.Other, '#pop') + ], + 'grammar-rules': [ + include('string'), + include('whitespace'), + (rf'(\[)({_ws}*)(badness)', + bygroups(Punctuation, using(this, state='whitespace'), Keyword), + 'main'), + (rf'->|{_operator}|[()]', Punctuation), + (_name, Name.Constant), + default('#pop:2') + ], + ':': [ + (r':', Punctuation, '#pop') + ], + 'function-name': [ + (r'(<<([^>]|>>>|>(?!>))*>>)+', String.Interpol), + (rf'(?={_name}?{_ws}*[({{])', Text, '#pop'), + (_name, Name.Function, '#pop'), + include('whitespace') + ], + 'inherited': [ + (r'<', Punctuation, ('#pop', 'classes', 'class')), + include('whitespace'), + (_name, Name.Class, '#pop'), + default('#pop') + ], + 'operator': [ + (r'negate\b', Operator.Word, '#pop'), + include('whitespace'), + (_operator, Operator), + default('#pop') + ], + 'propertyset': [ + (r'\(', Punctuation, ('more/parameters', 'main/parameters')), + (r'\{', Punctuation, ('#pop', 'object-body')), + include('whitespace') + ], + 'template': [ + (r'(?=;)', Text, '#pop'), + include('string'), + (r'inherited\b', Keyword.Reserved), + include('whitespace'), + (rf'->|\?|{_operator}', Punctuation), + (_name, Name.Variable) + ], + + # Identifiers + 'class': [ + (r'\*|\.{3}', Punctuation, '#pop'), + (r'object\b', Keyword.Reserved, '#pop'), + (r'transient\b', Keyword.Reserved), + (_name, Name.Class, '#pop'), + include('whitespace'), + default('#pop') + ], + 'classes': [ + (r'[:,]', Punctuation, 'class'), + include('whitespace'), + (r'>', Punctuation, '#pop'), + default('#pop') + ], + 'constants': [ + (r',+', Punctuation), + (r';', Punctuation, '#pop'), + (r'property\b', Keyword.Reserved), + (_name, Name.Constant), + include('whitespace') + ], + 'label': [ + (_name, Name.Label, '#pop'), + include('whitespace'), + default('#pop') + ], + 'variables': [ + (r',+', Punctuation), + (r'\)', Punctuation, '#pop'), + include('whitespace'), + (_name, Name.Variable) + ], + + # Whitespace and comments + 'whitespace': [ + (rf'^{_ws_pp}*#({_comment_multiline}|[^\n]|(?<=\\)\n)*\n?', + Comment.Preproc), + (_comment_single, Comment.Single), + (_comment_multiline, Comment.Multiline), + (rf'\\+\n+{_ws_pp}*#?|\n+|([^\S\n]|\\)+', Text) + ], + + # Strings + 'string': [ + (r'"""', String.Double, 'tdqs'), + (r"'''", String.Single, 'tsqs'), + (r'"', String.Double, 'dqs'), + (r"'", String.Single, 'sqs') + ], + 's/escape': [ + (rf'\{{\{{|\}}\}}|{_escape}', String.Escape) + ], + 's/verbatim': [ + (r'<<\s*(as\s+decreasingly\s+likely\s+outcomes|cycling|else|end|' + r'first\s+time|one\s+of|only|or|otherwise|' + r'(sticky|(then\s+)?(purely\s+)?at)\s+random|stopping|' + r'(then\s+)?(half\s+)?shuffled|\|\|)\s*>>', String.Interpol), + (rf'<<(%(_({_escape}|\\?.)|[\-+ ,#]|\[\d*\]?)*\d*\.?\d*({_escape}|\\?.)|' + r'\s*((else|otherwise)\s+)?(if|unless)\b)?', + String.Interpol, ('block/embed', 'more/embed', 'main')) + ], + 's/entity': [ + (r'(?i)&(#(x[\da-f]+|\d+)|[a-z][\da-z]*);?', Name.Entity) + ], + 'tdqs': _make_string_state(True, True), + 'tsqs': _make_string_state(True, False), + 'dqs': _make_string_state(False, True), + 'sqs': _make_string_state(False, False), + 'tdqs/listing': _make_string_state(True, True, 'listing'), + 'tsqs/listing': _make_string_state(True, False, 'listing'), + 'dqs/listing': _make_string_state(False, True, 'listing'), + 'sqs/listing': _make_string_state(False, False, 'listing'), + 'tdqs/xmp': _make_string_state(True, True, 'xmp'), + 'tsqs/xmp': _make_string_state(True, False, 'xmp'), + 'dqs/xmp': _make_string_state(False, True, 'xmp'), + 'sqs/xmp': _make_string_state(False, False, 'xmp'), + + # Tags + 'tdqt': _make_tag_state(True, True), + 'tsqt': _make_tag_state(True, False), + 'dqt': _make_tag_state(False, True), + 'sqt': _make_tag_state(False, False), + 'dqs/tdqt': _make_attribute_value_state(r'"', True, True), + 'dqs/tsqt': _make_attribute_value_state(r'"', True, False), + 'dqs/dqt': _make_attribute_value_state(r'"', False, True), + 'dqs/sqt': _make_attribute_value_state(r'"', False, False), + 'sqs/tdqt': _make_attribute_value_state(r"'", True, True), + 'sqs/tsqt': _make_attribute_value_state(r"'", True, False), + 'sqs/dqt': _make_attribute_value_state(r"'", False, True), + 'sqs/sqt': _make_attribute_value_state(r"'", False, False), + 'uqs/tdqt': _make_attribute_value_state(_no_quote, True, True), + 'uqs/tsqt': _make_attribute_value_state(_no_quote, True, False), + 'uqs/dqt': _make_attribute_value_state(_no_quote, False, True), + 'uqs/sqt': _make_attribute_value_state(_no_quote, False, False), + + # Regular expressions + 'tdqr': [ + (r'[^\\"]+', String.Regex), + (r'\\"*', String.Regex), + (r'"{3,}', String.Regex, '#pop'), + (r'"', String.Regex) + ], + 'tsqr': [ + (r"[^\\']+", String.Regex), + (r"\\'*", String.Regex), + (r"'{3,}", String.Regex, '#pop'), + (r"'", String.Regex) + ], + 'dqr': [ + (r'[^\\"]+', String.Regex), + (r'\\"?', String.Regex), + (r'"', String.Regex, '#pop') + ], + 'sqr': [ + (r"[^\\']+", String.Regex), + (r"\\'?", String.Regex), + (r"'", String.Regex, '#pop') + ] + } + + def get_tokens_unprocessed(self, text, **kwargs): + pp = rf'^{self._ws_pp}*#{self._ws_pp}*' + if_false_level = 0 + for index, token, value in ( + RegexLexer.get_tokens_unprocessed(self, text, **kwargs)): + if if_false_level == 0: # Not in a false #if + if (token is Comment.Preproc and + re.match(rf'{pp}if{self._ws_pp}+(0|nil){self._ws_pp}*$\n?', value)): + if_false_level = 1 + else: # In a false #if + if token is Comment.Preproc: + if (if_false_level == 1 and + re.match(rf'{pp}el(if|se)\b', value)): + if_false_level = 0 + elif re.match(rf'{pp}if', value): + if_false_level += 1 + elif re.match(rf'{pp}endif\b', value): + if_false_level -= 1 + else: + token = Comment + yield index, token, value + + def analyse_text(text): + """This is a rather generic descriptive language without strong + identifiers. It looks like a 'GameMainDef' has to be present, + and/or a 'versionInfo' with an 'IFID' field.""" + result = 0 + if '__TADS' in text or 'GameMainDef' in text: + result += 0.2 + + # This is a fairly unique keyword which is likely used in source as well + if 'versionInfo' in text and 'IFID' in text: + result += 0.1 + + return result diff --git a/pygments/lexers/iolang.py b/pygments/lexers/iolang.py new file mode 100644 index 0000000000000000000000000000000000000000..0269384fac1ac51eb6c0fff6c0d2a380a7b3b887 --- /dev/null +++ b/pygments/lexers/iolang.py @@ -0,0 +1,61 @@ +""" + pygments.lexers.iolang + ~~~~~~~~~~~~~~~~~~~~~~ + + Lexers for the Io language. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +from pygments.lexer import RegexLexer +from pygments.token import Comment, Operator, Keyword, Name, String, Number, \ + Whitespace + +__all__ = ['IoLexer'] + + +class IoLexer(RegexLexer): + """ + For Io (a small, prototype-based programming language) source. + """ + name = 'Io' + url = 'http://iolanguage.com/' + filenames = ['*.io'] + aliases = ['io'] + mimetypes = ['text/x-iosrc'] + version_added = '0.10' + tokens = { + 'root': [ + (r'\n', Whitespace), + (r'\s+', Whitespace), + # Comments + (r'//(.*?)$', Comment.Single), + (r'#(.*?)$', Comment.Single), + (r'/(\\\n)?[*](.|\n)*?[*](\\\n)?/', Comment.Multiline), + (r'/\+', Comment.Multiline, 'nestedcomment'), + # DoubleQuotedString + (r'"(\\\\|\\[^\\]|[^"\\])*"', String), + # Operators + (r'::=|:=|=|\(|\)|;|,|\*|-|\+|>|<|@|!|/|\||\^|\.|%|&|\[|\]|\{|\}', + Operator), + # keywords + (r'(clone|do|doFile|doString|method|for|if|else|elseif|then)\b', + Keyword), + # constants + (r'(nil|false|true)\b', Name.Constant), + # names + (r'(Object|list|List|Map|args|Sequence|Coroutine|File)\b', + Name.Builtin), + (r'[a-zA-Z_]\w*', Name), + # numbers + (r'(\d+\.?\d*|\d*\.\d+)([eE][+-]?[0-9]+)?', Number.Float), + (r'\d+', Number.Integer) + ], + 'nestedcomment': [ + (r'[^+/]+', Comment.Multiline), + (r'/\+', Comment.Multiline, '#push'), + (r'\+/', Comment.Multiline, '#pop'), + (r'[+/]', Comment.Multiline), + ] + } diff --git a/pygments/lexers/j.py b/pygments/lexers/j.py new file mode 100644 index 0000000000000000000000000000000000000000..1eb418e1c386eac1e6be2d833ba11ae8134ca31e --- /dev/null +++ b/pygments/lexers/j.py @@ -0,0 +1,151 @@ +""" + pygments.lexers.j + ~~~~~~~~~~~~~~~~~ + + Lexer for the J programming language. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +from pygments.lexer import RegexLexer, words, include, bygroups +from pygments.token import Comment, Keyword, Name, Number, Operator, \ + Punctuation, String, Whitespace + +__all__ = ['JLexer'] + + +class JLexer(RegexLexer): + """ + For J source code. + """ + + name = 'J' + url = 'http://jsoftware.com/' + aliases = ['j'] + filenames = ['*.ijs'] + mimetypes = ['text/x-j'] + version_added = '2.1' + + validName = r'\b[a-zA-Z]\w*' + + tokens = { + 'root': [ + # Shebang script + (r'#!.*$', Comment.Preproc), + + # Comments + (r'NB\..*', Comment.Single), + (r'(\n+\s*)(Note)', bygroups(Whitespace, Comment.Multiline), + 'comment'), + (r'(\s*)(Note.*)', bygroups(Whitespace, Comment.Single)), + + # Whitespace + (r'\s+', Whitespace), + + # Strings + (r"'", String, 'singlequote'), + + # Definitions + (r'0\s+:\s*0', Name.Entity, 'nounDefinition'), + (r'(noun)(\s+)(define)(\s*)$', bygroups(Name.Entity, Whitespace, + Name.Entity, Whitespace), 'nounDefinition'), + (r'([1-4]|13)\s+:\s*0\b', + Name.Function, 'explicitDefinition'), + (r'(adverb|conjunction|dyad|monad|verb)(\s+)(define)\b', + bygroups(Name.Function, Whitespace, Name.Function), + 'explicitDefinition'), + + # Flow Control + (words(('for_', 'goto_', 'label_'), suffix=validName+r'\.'), Name.Label), + (words(( + 'assert', 'break', 'case', 'catch', 'catchd', + 'catcht', 'continue', 'do', 'else', 'elseif', + 'end', 'fcase', 'for', 'if', 'return', + 'select', 'throw', 'try', 'while', 'whilst', + ), suffix=r'\.'), Name.Label), + + # Variable Names + (validName, Name.Variable), + + # Standard Library + (words(( + 'ARGV', 'CR', 'CRLF', 'DEL', 'Debug', + 'EAV', 'EMPTY', 'FF', 'JVERSION', 'LF', + 'LF2', 'Note', 'TAB', 'alpha17', 'alpha27', + 'apply', 'bind', 'boxopen', 'boxxopen', 'bx', + 'clear', 'cutLF', 'cutopen', 'datatype', 'def', + 'dfh', 'drop', 'each', 'echo', 'empty', + 'erase', 'every', 'evtloop', 'exit', 'expand', + 'fetch', 'file2url', 'fixdotdot', 'fliprgb', 'getargs', + 'getenv', 'hfd', 'inv', 'inverse', 'iospath', + 'isatty', 'isutf8', 'items', 'leaf', 'list', + 'nameclass', 'namelist', 'names', 'nc', + 'nl', 'on', 'pick', 'rows', + 'script', 'scriptd', 'sign', 'sminfo', 'smoutput', + 'sort', 'split', 'stderr', 'stdin', 'stdout', + 'table', 'take', 'timespacex', 'timex', 'tmoutput', + 'toCRLF', 'toHOST', 'toJ', 'tolower', 'toupper', + 'type', 'ucp', 'ucpcount', 'usleep', 'utf8', + 'uucp', + )), Name.Function), + + # Copula + (r'=[.:]', Operator), + + # Builtins + (r'[-=+*#$%@!~`^&";:.,<>{}\[\]\\|/?]', Operator), + + # Short Keywords + (r'[abCdDeEfHiIjLMoprtT]\.', Keyword.Reserved), + (r'[aDiLpqsStux]\:', Keyword.Reserved), + (r'(_[0-9])\:', Keyword.Constant), + + # Parens + (r'\(', Punctuation, 'parentheses'), + + # Numbers + include('numbers'), + ], + + 'comment': [ + (r'[^)]', Comment.Multiline), + (r'^\)', Comment.Multiline, '#pop'), + (r'[)]', Comment.Multiline), + ], + + 'explicitDefinition': [ + (r'\b[nmuvxy]\b', Name.Decorator), + include('root'), + (r'[^)]', Name), + (r'^\)', Name.Label, '#pop'), + (r'[)]', Name), + ], + + 'numbers': [ + (r'\b_{1,2}\b', Number), + (r'_?\d+(\.\d+)?(\s*[ejr]\s*)_?\d+(\.?=\d+)?', Number), + (r'_?\d+\.(?=\d+)', Number.Float), + (r'_?\d+x', Number.Integer.Long), + (r'_?\d+', Number.Integer), + ], + + 'nounDefinition': [ + (r'[^)]+', String), + (r'^\)', Name.Label, '#pop'), + (r'[)]', String), + ], + + 'parentheses': [ + (r'\)', Punctuation, '#pop'), + # include('nounDefinition'), + include('explicitDefinition'), + include('root'), + ], + + 'singlequote': [ + (r"[^']+", String), + (r"''", String), + (r"'", String, '#pop'), + ], + } diff --git a/pygments/lexers/javascript.py b/pygments/lexers/javascript.py new file mode 100644 index 0000000000000000000000000000000000000000..d361b6e716e45792fa9f41a4137c03bb039798bc --- /dev/null +++ b/pygments/lexers/javascript.py @@ -0,0 +1,1591 @@ +""" + pygments.lexers.javascript + ~~~~~~~~~~~~~~~~~~~~~~~~~~ + + Lexers for JavaScript and related languages. + + :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +import re + +from pygments.lexer import bygroups, combined, default, do_insertions, include, \ + inherit, Lexer, RegexLexer, this, using, words, line_re +from pygments.token import Text, Comment, Operator, Keyword, Name, String, \ + Number, Punctuation, Other, Generic, Whitespace +from pygments.util import get_bool_opt +import pygments.unistring as uni + +__all__ = ['JavascriptLexer', 'KalLexer', 'LiveScriptLexer', 'DartLexer', + 'TypeScriptLexer', 'LassoLexer', 'ObjectiveJLexer', + 'CoffeeScriptLexer', 'MaskLexer', 'EarlGreyLexer', 'JuttleLexer', + 'NodeConsoleLexer'] + +JS_IDENT_START = ('(?:[$_' + uni.combine('Lu', 'Ll', 'Lt', 'Lm', 'Lo', 'Nl') + + ']|\\\\u[a-fA-F0-9]{4})') +JS_IDENT_PART = ('(?:[$' + uni.combine('Lu', 'Ll', 'Lt', 'Lm', 'Lo', 'Nl', + 'Mn', 'Mc', 'Nd', 'Pc') + + '\u200c\u200d]|\\\\u[a-fA-F0-9]{4})') +JS_IDENT = JS_IDENT_START + '(?:' + JS_IDENT_PART + ')*' + + +class JavascriptLexer(RegexLexer): + """ + For JavaScript source code. + """ + + name = 'JavaScript' + url = 'https://www.ecma-international.org/publications-and-standards/standards/ecma-262/' + aliases = ['javascript', 'js'] + filenames = ['*.js', '*.jsm', '*.mjs', '*.cjs'] + mimetypes = ['application/javascript', 'application/x-javascript', + 'text/x-javascript', 'text/javascript'] + version_added = '' + + flags = re.DOTALL | re.MULTILINE + + tokens = { + 'commentsandwhitespace': [ + (r'\s+', Whitespace), + (r')?', Other), + (r'[^[<]+', Other), + ], + 'nosquarebrackets': [ + (r'\[noprocess\]', Comment.Preproc, 'noprocess'), + (r'\[', Other), + (r'<\?(lasso(script)?|=)', Comment.Preproc, 'anglebrackets'), + (r'<(!--.*?-->)?', Other), + (r'[^[<]+', Other), + ], + 'noprocess': [ + (r'\[/noprocess\]', Comment.Preproc, '#pop'), + (r'\[', Other), + (r'[^[]', Other), + ], + 'squarebrackets': [ + (r'\]', Comment.Preproc, '#pop'), + include('lasso'), + ], + 'anglebrackets': [ + (r'\?>', Comment.Preproc, '#pop'), + include('lasso'), + ], + 'lassofile': [ + (r'\]|\?>', Comment.Preproc, '#pop'), + include('lasso'), + ], + 'whitespacecomments': [ + (r'\s+', Whitespace), + (r'(//.*?)(\s*)$', bygroups(Comment.Single, Whitespace)), + (r'/\*\*!.*?\*/', String.Doc), + (r'/\*.*?\*/', Comment.Multiline), + ], + 'lasso': [ + # whitespace/comments + include('whitespacecomments'), + + # literals + (r'\d*\.\d+(e[+-]?\d+)?', Number.Float), + (r'0x[\da-f]+', Number.Hex), + (r'\d+', Number.Integer), + (r'(infinity|NaN)\b', Number), + (r"'", String.Single, 'singlestring'), + (r'"', String.Double, 'doublestring'), + (r'`[^`]*`', String.Backtick), + + # names + (r'\$[a-z_][\w.]*', Name.Variable), + (r'#([a-z_][\w.]*|\d+\b)', Name.Variable.Instance), + (r"(\.)(\s*)('[a-z_][\w.]*')", + bygroups(Name.Builtin.Pseudo, Whitespace, Name.Variable.Class)), + (r"(self)(\s*)(->)(\s*)('[a-z_][\w.]*')", + bygroups(Name.Builtin.Pseudo, Whitespace, Operator, Whitespace, + Name.Variable.Class)), + (r'(\.\.?)(\s*)([a-z_][\w.]*(=(?!=))?)', + bygroups(Name.Builtin.Pseudo, Whitespace, Name.Other.Member)), + (r'(->\\?|&)(\s*)([a-z_][\w.]*(=(?!=))?)', + bygroups(Operator, Whitespace, Name.Other.Member)), + (r'(?)(self|inherited|currentcapture|givenblock)\b', + Name.Builtin.Pseudo), + (r'-(?!infinity)[a-z_][\w.]*', Name.Attribute), + (r'(::)(\s*)([a-z_][\w.]*)', + bygroups(Punctuation, Whitespace, Name.Label)), + (r'(error_(code|msg)_\w+|Error_AddError|Error_ColumnRestriction|' + r'Error_DatabaseConnectionUnavailable|Error_DatabaseTimeout|' + r'Error_DeleteError|Error_FieldRestriction|Error_FileNotFound|' + r'Error_InvalidDatabase|Error_InvalidPassword|' + r'Error_InvalidUsername|Error_ModuleNotFound|' + r'Error_NoError|Error_NoPermission|Error_OutOfMemory|' + r'Error_ReqColumnMissing|Error_ReqFieldMissing|' + r'Error_RequiredColumnMissing|Error_RequiredFieldMissing|' + r'Error_UpdateError)\b', Name.Exception), + + # definitions + (r'(define)(\s+)([a-z_][\w.]*)(\s*)(=>)(\s*)(type|trait|thread)\b', + bygroups(Keyword.Declaration, Whitespace, Name.Class, + Whitespace, Operator, Whitespace, Keyword)), + (r'(define)(\s+)([a-z_][\w.]*)(\s*)(->)(\s*)([a-z_][\w.]*=?|[-+*/%])', + bygroups(Keyword.Declaration, Whitespace, Name.Class, + Whitespace, Operator, Whitespace, Name.Function), + 'signature'), + (r'(define)(\s+)([a-z_][\w.]*)', + bygroups(Keyword.Declaration, Whitespace, Name.Function), 'signature'), + (r'(public|protected|private|provide)(\s+)(([a-z_][\w.]*=?|[-+*/%])' + r'(?=\s*\())', bygroups(Keyword, Whitespace, Name.Function), + 'signature'), + (r'(public|protected|private|provide)(\s+)([a-z_][\w.]*)', + bygroups(Keyword, Whitespace, Name.Function)), + + # keywords + (r'(true|false|none|minimal|full|all|void)\b', Keyword.Constant), + (r'(local|var|variable|global|data(?=\s))\b', Keyword.Declaration), + (r'(array|date|decimal|duration|integer|map|pair|string|tag|xml|' + r'null|boolean|bytes|keyword|list|locale|queue|set|stack|' + r'staticarray)\b', Keyword.Type), + (r'([a-z_][\w.]*)(\s+)(in)\b', bygroups(Name, Whitespace, Keyword)), + (r'(let|into)(\s+)([a-z_][\w.]*)', bygroups(Keyword, Whitespace, Name)), + (r'require\b', Keyword, 'requiresection'), + (r'(/?)(Namespace_Using)\b', bygroups(Punctuation, Keyword.Namespace)), + (r'(/?)(Cache|Database_Names|Database_SchemaNames|' + r'Database_TableNames|Define_Tag|Define_Type|Email_Batch|' + r'Encode_Set|HTML_Comment|Handle|Handle_Error|Header|If|Inline|' + r'Iterate|LJAX_Target|Link|Link_CurrentAction|Link_CurrentGroup|' + r'Link_CurrentRecord|Link_Detail|Link_FirstGroup|Link_FirstRecord|' + r'Link_LastGroup|Link_LastRecord|Link_NextGroup|Link_NextRecord|' + r'Link_PrevGroup|Link_PrevRecord|Log|Loop|Output_None|Portal|' + r'Private|Protect|Records|Referer|Referrer|Repeating|ResultSet|' + r'Rows|Search_Args|Search_Arguments|Select|Sort_Args|' + r'Sort_Arguments|Thread_Atomic|Value_List|While|Abort|Case|Else|' + r'Fail_If|Fail_IfNot|Fail|If_Empty|If_False|If_Null|If_True|' + r'Loop_Abort|Loop_Continue|Loop_Count|Params|Params_Up|Return|' + r'Return_Value|Run_Children|SOAP_DefineTag|SOAP_LastRequest|' + r'SOAP_LastResponse|Tag_Name|ascending|average|by|define|' + r'descending|do|equals|frozen|group|handle_failure|import|in|into|' + r'join|let|match|max|min|on|order|parent|protected|provide|public|' + r'require|returnhome|skip|split_thread|sum|take|thread|to|trait|' + r'type|where|with|yield|yieldhome)\b', + bygroups(Punctuation, Keyword)), + + # other + (r',', Punctuation, 'commamember'), + (r'(and|or|not)\b', Operator.Word), + (r'([a-z_][\w.]*)(\s*)(::)(\s*)([a-z_][\w.]*)?(\s*=(?!=))', + bygroups(Name, Whitespace, Punctuation, Whitespace, Name.Label, + Operator)), + (r'(/?)([\w.]+)', bygroups(Punctuation, Name.Other)), + (r'(=)(n?bw|n?ew|n?cn|lte?|gte?|n?eq|n?rx|ft)\b', + bygroups(Operator, Operator.Word)), + (r':=|[-+*/%=<>&|!?\\]+', Operator), + (r'[{}():;,@^]', Punctuation), + ], + 'singlestring': [ + (r"'", String.Single, '#pop'), + (r"[^'\\]+", String.Single), + include('escape'), + (r"\\", String.Single), + ], + 'doublestring': [ + (r'"', String.Double, '#pop'), + (r'[^"\\]+', String.Double), + include('escape'), + (r'\\', String.Double), + ], + 'escape': [ + (r'\\(U[\da-f]{8}|u[\da-f]{4}|x[\da-f]{1,2}|[0-7]{1,3}|:[^:\n\r]+:|' + r'[abefnrtv?"\'\\]|$)', String.Escape), + ], + 'signature': [ + (r'=>', Operator, '#pop'), + (r'\)', Punctuation, '#pop'), + (r'[(,]', Punctuation, 'parameter'), + include('lasso'), + ], + 'parameter': [ + (r'\)', Punctuation, '#pop'), + (r'-?[a-z_][\w.]*', Name.Attribute, '#pop'), + (r'\.\.\.', Name.Builtin.Pseudo), + include('lasso'), + ], + 'requiresection': [ + (r'(([a-z_][\w.]*=?|[-+*/%])(?=\s*\())', Name, 'requiresignature'), + (r'(([a-z_][\w.]*=?|[-+*/%])(?=(\s*::\s*[\w.]+)?\s*,))', Name), + (r'[a-z_][\w.]*=?|[-+*/%]', Name, '#pop'), + (r'(::)(\s*)([a-z_][\w.]*)', + bygroups(Punctuation, Whitespace, Name.Label)), + (r',', Punctuation), + include('whitespacecomments'), + ], + 'requiresignature': [ + (r'(\)(?=(\s*::\s*[\w.]+)?\s*,))', Punctuation, '#pop'), + (r'\)', Punctuation, '#pop:2'), + (r'-?[a-z_][\w.]*', Name.Attribute), + (r'(::)(\s*)([a-z_][\w.]*)', + bygroups(Punctuation, Whitespace, Name.Label)), + (r'\.\.\.', Name.Builtin.Pseudo), + (r'[(,]', Punctuation), + include('whitespacecomments'), + ], + 'commamember': [ + (r'(([a-z_][\w.]*=?|[-+*/%])' + r'(?=\s*(\(([^()]*\([^()]*\))*[^)]*\)\s*)?(::[\w.\s]+)?=>))', + Name.Function, 'signature'), + include('whitespacecomments'), + default('#pop'), + ], + } + + def __init__(self, **options): + self.builtinshighlighting = get_bool_opt( + options, 'builtinshighlighting', True) + self.requiredelimiters = get_bool_opt( + options, 'requiredelimiters', False) + + self._builtins = set() + self._members = set() + if self.builtinshighlighting: + from pygments.lexers._lasso_builtins import BUILTINS, MEMBERS + for key, value in BUILTINS.items(): + self._builtins.update(value) + for key, value in MEMBERS.items(): + self._members.update(value) + RegexLexer.__init__(self, **options) + + def get_tokens_unprocessed(self, text): + stack = ['root'] + if self.requiredelimiters: + stack.append('delimiters') + for index, token, value in \ + RegexLexer.get_tokens_unprocessed(self, text, stack): + if (token is Name.Other and value.lower() in self._builtins or + token is Name.Other.Member and + value.lower().rstrip('=') in self._members): + yield index, Name.Builtin, value + continue + yield index, token, value + + def analyse_text(text): + rv = 0.0 + if 'bin/lasso9' in text: + rv += 0.8 + if re.search(r'<\?lasso', text, re.I): + rv += 0.4 + if re.search(r'local\(', text, re.I): + rv += 0.4 + return rv + + +class ObjectiveJLexer(RegexLexer): + """ + For Objective-J source code with preprocessor directives. + """ + + name = 'Objective-J' + aliases = ['objective-j', 'objectivej', 'obj-j', 'objj'] + filenames = ['*.j'] + mimetypes = ['text/x-objective-j'] + url = 'https://www.cappuccino.dev/learn/objective-j.html' + version_added = '1.3' + + #: optional Comment or Whitespace + _ws = r'(?:\s|//[^\n]*\n|/[*](?:[^*]|[*][^/])*[*]/)*' + + flags = re.DOTALL | re.MULTILINE + + tokens = { + 'root': [ + include('whitespace'), + + # function definition + (r'^(' + _ws + r'[+-]' + _ws + r')([(a-zA-Z_].*?[^(])(' + _ws + r'\{)', + bygroups(using(this), using(this, state='function_signature'), + using(this))), + + # class definition + (r'(@interface|@implementation)(\s+)', bygroups(Keyword, Whitespace), + 'classname'), + (r'(@class|@protocol)(\s*)', bygroups(Keyword, Whitespace), + 'forward_classname'), + (r'(\s*)(@end)(\s*)', bygroups(Whitespace, Keyword, Whitespace)), + + include('statements'), + ('[{()}]', Punctuation), + (';', Punctuation), + ], + 'whitespace': [ + (r'(@import)(\s+)("(?:\\\\|\\"|[^"])*")', + bygroups(Comment.Preproc, Whitespace, String.Double)), + (r'(@import)(\s+)(<(?:\\\\|\\>|[^>])*>)', + bygroups(Comment.Preproc, Whitespace, String.Double)), + (r'(#(?:include|import))(\s+)("(?:\\\\|\\"|[^"])*")', + bygroups(Comment.Preproc, Whitespace, String.Double)), + (r'(#(?:include|import))(\s+)(<(?:\\\\|\\>|[^>])*>)', + bygroups(Comment.Preproc, Whitespace, String.Double)), + + (r'#if\s+0', Comment.Preproc, 'if0'), + (r'#', Comment.Preproc, 'macro'), + + (r'\s+', Whitespace), + (r'(\\)(\n)', + bygroups(String.Escape, Whitespace)), # line continuation + (r'//(\n|(.|\n)*?[^\\]\n)', Comment.Single), + (r'/(\\\n)?[*](.|\n)*?[*](\\\n)?/', Comment.Multiline), + (r') + if attr_content[index-2:index] != '--': + break + start = index + 1 + + if index == -1: + # No tag end + yield from self.get_tokens_unprocessed(attr_content, stack=['root', 'attr']) + return + attr = attr_content[:index] + yield from self.get_tokens_unprocessed(attr, stack=['root', 'attr']) + yield match.start(3) + index, Punctuation, '>' + + lexer = None + content = attr_content[index+1:] + lang_match = re.findall(r'\blang=("|\'|)(\w+)(\1)', attr) + + if len(lang_match) >= 1: + # Pick the last match in case of multiple matches + lang = lang_match[-1][1] + try: + lexer = get_lexer_by_name(lang) + except ClassNotFound: + pass + + if lexer is None: + yield match.start() + index + 1, Text, content + else: + yield from lexer.get_tokens_unprocessed(content) + + def handle_score(self, match, ctx): + attr_content = match.group() + start = 0 + index = 0 + while True: + index = attr_content.find('>', start) + # Exclude comment end (-->) + if attr_content[index-2:index] != '--': + break + start = index + 1 + + if index == -1: + # No tag end + yield from self.get_tokens_unprocessed(attr_content, stack=['root', 'attr']) + return + attr = attr_content[:index] + content = attr_content[index+1:] + yield from self.get_tokens_unprocessed(attr, stack=['root', 'attr']) + yield match.start(3) + index, Punctuation, '>' + + lang_match = re.findall(r'\blang=("|\'|)(\w+)(\1)', attr) + # Pick the last match in case of multiple matches + lang = lang_match[-1][1] if len(lang_match) >= 1 else 'lilypond' + + if lang == 'lilypond': # Case sensitive + yield from LilyPondLexer().get_tokens_unprocessed(content) + else: # ABC + # FIXME: Use ABC lexer in the future + yield match.start() + index + 1, Text, content + + # a-z removed to prevent linter from complaining, REMEMBER to use (?i) + title_char = r' %!"$&\'()*,\-./0-9:;=?@A-Z\\\^_`~+\u0080-\uFFFF' + nbsp_char = r'(?:\t| |&\#0*160;|&\#[Xx]0*[Aa]0;|[ \xA0\u1680\u2000-\u200A\u202F\u205F\u3000])' + link_address = r'(?:[0-9.]+|\[[0-9a-f:.]+\]|[^\x00-\x20"<>\[\]\x7F\xA0\u1680\u2000-\u200A\u202F\u205F\u3000\uFFFD])' + link_char_class = r'[^\x00-\x20"<>\[\]\x7F\xA0\u1680\u2000-\u200A\u202F\u205F\u3000\uFFFD]' + double_slashes_i = { + '__FORCETOC__', '__NOCONTENTCONVERT__', '__NOCC__', '__NOEDITSECTION__', '__NOGALLERY__', + '__NOTITLECONVERT__', '__NOTC__', '__NOTOC__', '__TOC__', + } + double_slashes = { + '__EXPECTUNUSEDCATEGORY__', '__HIDDENCAT__', '__INDEX__', '__NEWSECTIONLINK__', + '__NOINDEX__', '__NONEWSECTIONLINK__', '__STATICREDIRECT__', '__NOGLOBAL__', + '__DISAMBIG__', '__EXPECTED_UNCONNECTED_PAGE__', + } + protocols = { + 'bitcoin:', 'ftp://', 'ftps://', 'geo:', 'git://', 'gopher://', 'http://', 'https://', + 'irc://', 'ircs://', 'magnet:', 'mailto:', 'mms://', 'news:', 'nntp://', 'redis://', + 'sftp://', 'sip:', 'sips:', 'sms:', 'ssh://', 'svn://', 'tel:', 'telnet://', 'urn:', + 'worldwind://', 'xmpp:', '//', + } + non_relative_protocols = protocols - {'//'} + html_tags = { + 'abbr', 'b', 'bdi', 'bdo', 'big', 'blockquote', 'br', 'caption', 'center', 'cite', 'code', + 'data', 'dd', 'del', 'dfn', 'div', 'dl', 'dt', 'em', 'font', 'h1', 'h2', 'h3', 'h4', 'h5', + 'h6', 'hr', 'i', 'ins', 'kbd', 'li', 'link', 'mark', 'meta', 'ol', 'p', 'q', 'rb', 'rp', + 'rt', 'rtc', 'ruby', 's', 'samp', 'small', 'span', 'strike', 'strong', 'sub', 'sup', + 'table', 'td', 'th', 'time', 'tr', 'tt', 'u', 'ul', 'var', 'wbr', + } + parser_tags = { + 'graph', 'charinsert', 'rss', 'chem', 'categorytree', 'nowiki', 'inputbox', 'math', + 'hiero', 'score', 'pre', 'ref', 'translate', 'imagemap', 'templatestyles', 'languages', + 'noinclude', 'mapframe', 'section', 'poem', 'syntaxhighlight', 'includeonly', 'tvar', + 'onlyinclude', 'templatedata', 'langconvert', 'timeline', 'dynamicpagelist', 'gallery', + 'maplink', 'ce', 'references', + } + variant_langs = { + # ZhConverter.php + 'zh', 'zh-hans', 'zh-hant', 'zh-cn', 'zh-hk', 'zh-mo', 'zh-my', 'zh-sg', 'zh-tw', + # WuuConverter.php + 'wuu', 'wuu-hans', 'wuu-hant', + # UzConverter.php + 'uz', 'uz-latn', 'uz-cyrl', + # TlyConverter.php + 'tly', 'tly-cyrl', + # TgConverter.php + 'tg', 'tg-latn', + # SrConverter.php + 'sr', 'sr-ec', 'sr-el', + # ShiConverter.php + 'shi', 'shi-tfng', 'shi-latn', + # ShConverter.php + 'sh-latn', 'sh-cyrl', + # KuConverter.php + 'ku', 'ku-arab', 'ku-latn', + # IuConverter.php + 'iu', 'ike-cans', 'ike-latn', + # GanConverter.php + 'gan', 'gan-hans', 'gan-hant', + # EnConverter.php + 'en', 'en-x-piglatin', + # CrhConverter.php + 'crh', 'crh-cyrl', 'crh-latn', + # BanConverter.php + 'ban', 'ban-bali', 'ban-x-dharma', 'ban-x-palmleaf', 'ban-x-pku', + } + magic_vars_i = { + 'ARTICLEPATH', 'INT', 'PAGEID', 'SCRIPTPATH', 'SERVER', 'SERVERNAME', 'STYLEPATH', + } + magic_vars = { + '!', '=', 'BASEPAGENAME', 'BASEPAGENAMEE', 'CASCADINGSOURCES', 'CONTENTLANGUAGE', + 'CONTENTLANG', 'CURRENTDAY', 'CURRENTDAY2', 'CURRENTDAYNAME', 'CURRENTDOW', 'CURRENTHOUR', + 'CURRENTMONTH', 'CURRENTMONTH2', 'CURRENTMONTH1', 'CURRENTMONTHABBREV', 'CURRENTMONTHNAME', + 'CURRENTMONTHNAMEGEN', 'CURRENTTIME', 'CURRENTTIMESTAMP', 'CURRENTVERSION', 'CURRENTWEEK', + 'CURRENTYEAR', 'DIRECTIONMARK', 'DIRMARK', 'FULLPAGENAME', 'FULLPAGENAMEE', 'LOCALDAY', + 'LOCALDAY2', 'LOCALDAYNAME', 'LOCALDOW', 'LOCALHOUR', 'LOCALMONTH', 'LOCALMONTH2', + 'LOCALMONTH1', 'LOCALMONTHABBREV', 'LOCALMONTHNAME', 'LOCALMONTHNAMEGEN', 'LOCALTIME', + 'LOCALTIMESTAMP', 'LOCALWEEK', 'LOCALYEAR', 'NAMESPACE', 'NAMESPACEE', 'NAMESPACENUMBER', + 'NUMBEROFACTIVEUSERS', 'NUMBEROFADMINS', 'NUMBEROFARTICLES', 'NUMBEROFEDITS', + 'NUMBEROFFILES', 'NUMBEROFPAGES', 'NUMBEROFUSERS', 'PAGELANGUAGE', 'PAGENAME', 'PAGENAMEE', + 'REVISIONDAY', 'REVISIONDAY2', 'REVISIONID', 'REVISIONMONTH', 'REVISIONMONTH1', + 'REVISIONSIZE', 'REVISIONTIMESTAMP', 'REVISIONUSER', 'REVISIONYEAR', 'ROOTPAGENAME', + 'ROOTPAGENAMEE', 'SITENAME', 'SUBJECTPAGENAME', 'ARTICLEPAGENAME', 'SUBJECTPAGENAMEE', + 'ARTICLEPAGENAMEE', 'SUBJECTSPACE', 'ARTICLESPACE', 'SUBJECTSPACEE', 'ARTICLESPACEE', + 'SUBPAGENAME', 'SUBPAGENAMEE', 'TALKPAGENAME', 'TALKPAGENAMEE', 'TALKSPACE', 'TALKSPACEE', + } + parser_functions_i = { + 'ANCHORENCODE', 'BIDI', 'CANONICALURL', 'CANONICALURLE', 'FILEPATH', 'FORMATNUM', + 'FULLURL', 'FULLURLE', 'GENDER', 'GRAMMAR', 'INT', r'\#LANGUAGE', 'LC', 'LCFIRST', 'LOCALURL', + 'LOCALURLE', 'NS', 'NSE', 'PADLEFT', 'PADRIGHT', 'PAGEID', 'PLURAL', 'UC', 'UCFIRST', + 'URLENCODE', + } + parser_functions = { + 'BASEPAGENAME', 'BASEPAGENAMEE', 'CASCADINGSOURCES', 'DEFAULTSORT', 'DEFAULTSORTKEY', + 'DEFAULTCATEGORYSORT', 'FULLPAGENAME', 'FULLPAGENAMEE', 'NAMESPACE', 'NAMESPACEE', + 'NAMESPACENUMBER', 'NUMBERINGROUP', 'NUMINGROUP', 'NUMBEROFACTIVEUSERS', 'NUMBEROFADMINS', + 'NUMBEROFARTICLES', 'NUMBEROFEDITS', 'NUMBEROFFILES', 'NUMBEROFPAGES', 'NUMBEROFUSERS', + 'PAGENAME', 'PAGENAMEE', 'PAGESINCATEGORY', 'PAGESINCAT', 'PAGESIZE', 'PROTECTIONEXPIRY', + 'PROTECTIONLEVEL', 'REVISIONDAY', 'REVISIONDAY2', 'REVISIONID', 'REVISIONMONTH', + 'REVISIONMONTH1', 'REVISIONTIMESTAMP', 'REVISIONUSER', 'REVISIONYEAR', 'ROOTPAGENAME', + 'ROOTPAGENAMEE', 'SUBJECTPAGENAME', 'ARTICLEPAGENAME', 'SUBJECTPAGENAMEE', + 'ARTICLEPAGENAMEE', 'SUBJECTSPACE', 'ARTICLESPACE', 'SUBJECTSPACEE', 'ARTICLESPACEE', + 'SUBPAGENAME', 'SUBPAGENAMEE', 'TALKPAGENAME', 'TALKPAGENAMEE', 'TALKSPACE', 'TALKSPACEE', + 'INT', 'DISPLAYTITLE', 'PAGESINNAMESPACE', 'PAGESINNS', + } + + tokens = { + 'root': [ + # Redirects + (r"""(?xi) + (\A\s*?)(\#REDIRECT:?) # may contain a colon + (\s+)(\[\[) (?=[^\]\n]* \]\]$) + """, + bygroups(Whitespace, Keyword, Whitespace, Punctuation), 'redirect-inner'), + # Subheadings + (r'^(={2,6})(.+?)(\1)(\s*$\n)', + bygroups(Generic.Subheading, Generic.Subheading, Generic.Subheading, Whitespace)), + # Headings + (r'^(=.+?=)(\s*$\n)', + bygroups(Generic.Heading, Whitespace)), + # Double-slashed magic words + (words(double_slashes_i, prefix=r'(?i)'), Name.Function.Magic), + (words(double_slashes), Name.Function.Magic), + # Raw URLs + (r'(?i)\b(?:{}){}{}*'.format('|'.join(protocols), + link_address, link_char_class), Name.Label), + # Magic links + (rf'\b(?:RFC|PMID){nbsp_char}+[0-9]+\b', + Name.Function.Magic), + (r"""(?x) + \bISBN {nbsp_char} + (?: 97[89] {nbsp_dash}? )? + (?: [0-9] {nbsp_dash}? ){{9}} # escape format() + [0-9Xx]\b + """.format(nbsp_char=nbsp_char, nbsp_dash=f'(?:-|{nbsp_char})'), Name.Function.Magic), + include('list'), + include('inline'), + include('text'), + ], + 'redirect-inner': [ + (r'(\]\])(\s*?\n)', bygroups(Punctuation, Whitespace), '#pop'), + (r'(\#)([^#]*?)', bygroups(Punctuation, Name.Label)), + (rf'(?i)[{title_char}]+', Name.Tag), + ], + 'list': [ + # Description lists + (r'^;', Keyword, 'dt'), + # Ordered lists, unordered lists and indents + (r'^[#:*]+', Keyword), + # Horizontal rules + (r'^-{4,}', Keyword), + ], + 'inline': [ + # Signatures + (r'~{3,5}', Keyword), + # Entities + include('entity'), + # Bold & italic + (r"('')(''')(?!')", bygroups(Generic.Emph, + Generic.EmphStrong), 'inline-italic-bold'), + (r"'''(?!')", Generic.Strong, 'inline-bold'), + (r"''(?!')", Generic.Emph, 'inline-italic'), + # Comments & parameters & templates + include('replaceable'), + # Media links + ( + r"""(?xi) + (\[\[) + (File|Image) (:) + ((?: [{}] | \{{{{2,3}}[^{{}}]*?\}}{{2,3}} | )*) + (?: (\#) ([{}]*?) )? + """.format(title_char, f'{title_char}#'), + bygroups(Punctuation, Name.Namespace, Punctuation, + using(this, state=['wikilink-name']), Punctuation, Name.Label), + 'medialink-inner' + ), + # Wikilinks + ( + r"""(?xi) + (\[\[)(?!{}) # Should not contain URLs + (?: ([{}]*) (:))? + ((?: [{}] | \{{{{2,3}}[^{{}}]*?\}}{{2,3}} | )*?) + (?: (\#) ([{}]*?) )? + (\]\]) + """.format('|'.join(protocols), title_char.replace('/', ''), + title_char, f'{title_char}#'), + bygroups(Punctuation, Name.Namespace, Punctuation, + using(this, state=['wikilink-name']), Punctuation, Name.Label, Punctuation) + ), + ( + r"""(?xi) + (\[\[)(?!{}) + (?: ([{}]*) (:))? + ((?: [{}] | \{{{{2,3}}[^{{}}]*?\}}{{2,3}} | )*?) + (?: (\#) ([{}]*?) )? + (\|) + """.format('|'.join(protocols), title_char.replace('/', ''), + title_char, f'{title_char}#'), + bygroups(Punctuation, Name.Namespace, Punctuation, + using(this, state=['wikilink-name']), Punctuation, Name.Label, Punctuation), + 'wikilink-inner' + ), + # External links + ( + r"""(?xi) + (\[) + ((?:{}) {} {}*) + (\s*) + """.format('|'.join(protocols), link_address, link_char_class), + bygroups(Punctuation, Name.Label, Whitespace), + 'extlink-inner' + ), + # Tables + (r'^(:*)(\s*?)(\{\|)([^\n]*)$', bygroups(Keyword, + Whitespace, Punctuation, using(this, state=['root', 'attr'])), 'table'), + # HTML tags + (r'(?i)(<)({})\b'.format('|'.join(html_tags)), + bygroups(Punctuation, Name.Tag), 'tag-inner-ordinary'), + (r'(?i)()'.format('|'.join(html_tags)), + bygroups(Punctuation, Name.Tag, Whitespace, Punctuation)), + # + (r'(?i)(<)(nowiki)\b', bygroups(Punctuation, + Name.Tag), ('tag-nowiki', 'tag-inner')), + #
+            (r'(?i)(<)(pre)\b', bygroups(Punctuation,
+             Name.Tag), ('tag-pre', 'tag-inner')),
+            # 
+            (r'(?i)(<)(categorytree)\b', bygroups(
+                Punctuation, Name.Tag), ('tag-categorytree', 'tag-inner')),
+            # 
+            (r'(?i)(<)(hiero)\b', bygroups(Punctuation,
+             Name.Tag), ('tag-hiero', 'tag-inner')),
+            # 
+            (r'(?i)(<)(math)\b', bygroups(Punctuation,
+             Name.Tag), ('tag-math', 'tag-inner')),
+            # 
+            (r'(?i)(<)(chem)\b', bygroups(Punctuation,
+             Name.Tag), ('tag-chem', 'tag-inner')),
+            # 
+            (r'(?i)(<)(ce)\b', bygroups(Punctuation,
+             Name.Tag), ('tag-ce', 'tag-inner')),
+            # 
+            (r'(?i)(<)(charinsert)\b', bygroups(
+                Punctuation, Name.Tag), ('tag-charinsert', 'tag-inner')),
+            # 
+            (r'(?i)(<)(templatedata)\b', bygroups(
+                Punctuation, Name.Tag), ('tag-templatedata', 'tag-inner')),
+            # 
+            (r'(?i)(<)(gallery)\b', bygroups(
+                Punctuation, Name.Tag), ('tag-gallery', 'tag-inner')),
+            # 
+            (r'(?i)(<)(gallery)\b', bygroups(
+                Punctuation, Name.Tag), ('tag-graph', 'tag-inner')),
+            # 
+            (r'(?i)(<)(dynamicpagelist)\b', bygroups(
+                Punctuation, Name.Tag), ('tag-dynamicpagelist', 'tag-inner')),
+            # 
+            (r'(?i)(<)(inputbox)\b', bygroups(
+                Punctuation, Name.Tag), ('tag-inputbox', 'tag-inner')),
+            # 
+            (r'(?i)(<)(rss)\b', bygroups(
+                Punctuation, Name.Tag), ('tag-rss', 'tag-inner')),
+            # 
+            (r'(?i)(<)(imagemap)\b', bygroups(
+                Punctuation, Name.Tag), ('tag-imagemap', 'tag-inner')),
+            # 
+            (r'(?i)()',
+             bygroups(Punctuation, Name.Tag, Whitespace, Punctuation)),
+            (r'(?si)(<)(syntaxhighlight)\b([^>]*?(?.*?)(?=)',
+             bygroups(Punctuation, Name.Tag, handle_syntaxhighlight)),
+            # : Fallback case for self-closing tags
+            (r'(?i)(<)(syntaxhighlight)\b(\s*?)((?:[^>]|-->)*?)(/\s*?(?)*?)(/\s*?(?)*?)(/\s*?(?|\Z)', Comment.Multiline),
+            # Parameters
+            (
+                r"""(?x)
+                (\{{3})
+                    ([^|]*?)
+                    (?=\}{3}|\|)
+                """,
+                bygroups(Punctuation, Name.Variable),
+                'parameter-inner',
+            ),
+            # Magic variables
+            (r'(?i)(\{{\{{)(\s*)({})(\s*)(\}}\}})'.format('|'.join(magic_vars_i)),
+             bygroups(Punctuation, Whitespace, Name.Function, Whitespace, Punctuation)),
+            (r'(\{{\{{)(\s*)({})(\s*)(\}}\}})'.format('|'.join(magic_vars)),
+                bygroups(Punctuation, Whitespace, Name.Function, Whitespace, Punctuation)),
+            # Parser functions & templates
+            (r'\{\{', Punctuation, 'template-begin-space'),
+            #  legacy syntax
+            (r'(?i)(<)(tvar)\b(\|)([^>]*?)(>)', bygroups(Punctuation,
+             Name.Tag, Punctuation, String, Punctuation)),
+            (r'', Punctuation, '#pop'),
+            # 
+            (r'(?i)(<)(tvar)\b', bygroups(Punctuation, Name.Tag), 'tag-inner-ordinary'),
+            (r'(?i)()',
+             bygroups(Punctuation, Name.Tag, Whitespace, Punctuation)),
+        ],
+        'parameter-inner': [
+            (r'\}{3}', Punctuation, '#pop'),
+            (r'\|', Punctuation),
+            include('inline'),
+            include('text'),
+        ],
+        'template-begin-space': [
+            # Templates allow line breaks at the beginning, and due to how MediaWiki handles
+            # comments, an extra state is required to handle things like {{\n\n name}}
+            (r'|\Z)', Comment.Multiline),
+            (r'\s+', Whitespace),
+            # Parser functions
+            (
+                r'(?i)(\#[{}]*?|{})(:)'.format(title_char,
+                                           '|'.join(parser_functions_i)),
+                bygroups(Name.Function, Punctuation), ('#pop', 'template-inner')
+            ),
+            (
+                r'({})(:)'.format('|'.join(parser_functions)),
+                bygroups(Name.Function, Punctuation), ('#pop', 'template-inner')
+            ),
+            # Templates
+            (
+                rf'(?i)([{title_char}]*?)(:)',
+                bygroups(Name.Namespace, Punctuation), ('#pop', 'template-name')
+            ),
+            default(('#pop', 'template-name'),),
+        ],
+        'template-name': [
+            (r'(\s*?)(\|)', bygroups(Text, Punctuation), ('#pop', 'template-inner')),
+            (r'\}\}', Punctuation, '#pop'),
+            (r'\n', Text, '#pop'),
+            include('replaceable'),
+            *text_rules(Name.Tag),
+        ],
+        'template-inner': [
+            (r'\}\}', Punctuation, '#pop'),
+            (r'\|', Punctuation),
+            (
+                r"""(?x)
+                    (?<=\|)
+                    ( (?: (?! \{\{ | \}\} )[^=\|<])*? ) # Exclude templates and tags
+                    (=)
+                """,
+                bygroups(Name.Label, Operator)
+            ),
+            include('inline'),
+            include('text'),
+        ],
+        'table': [
+            # Use [ \t\n\r\0\x0B] instead of \s to follow PHP trim() behavior
+            # Endings
+            (r'^([ \t\n\r\0\x0B]*?)(\|\})',
+             bygroups(Whitespace, Punctuation), '#pop'),
+            # Table rows
+            (r'^([ \t\n\r\0\x0B]*?)(\|-+)(.*)$', bygroups(Whitespace, Punctuation,
+             using(this, state=['root', 'attr']))),
+            # Captions
+            (
+                r"""(?x)
+                ^([ \t\n\r\0\x0B]*?)(\|\+)
+                # Exclude links, template and tags
+                (?: ( (?: (?! \[\[ | \{\{ )[^|\n<] )*? )(\|) )?
+                (.*?)$
+                """,
+                bygroups(Whitespace, Punctuation, using(this, state=[
+                         'root', 'attr']), Punctuation, Generic.Heading),
+            ),
+            # Table data
+            (
+                r"""(?x)
+                ( ^(?:[ \t\n\r\0\x0B]*?)\| | \|\| )
+                (?: ( (?: (?! \[\[ | \{\{ )[^|\n<] )*? )(\|)(?!\|) )?
+                """,
+                bygroups(Punctuation, using(this, state=[
+                         'root', 'attr']), Punctuation),
+            ),
+            # Table headers
+            (
+                r"""(?x)
+                ( ^(?:[ \t\n\r\0\x0B]*?)!  )
+                (?: ( (?: (?! \[\[ | \{\{ )[^|\n<] )*? )(\|)(?!\|) )?
+                """,
+                bygroups(Punctuation, using(this, state=[
+                         'root', 'attr']), Punctuation),
+                'table-header',
+            ),
+            include('list'),
+            include('inline'),
+            include('text'),
+        ],
+        'table-header': [
+            # Requires another state for || handling inside headers
+            (r'\n', Text, '#pop'),
+            (
+                r"""(?x)
+                (!!|\|\|)
+                (?:
+                    ( (?: (?! \[\[ | \{\{ )[^|\n<] )*? )
+                    (\|)(?!\|)
+                )?
+                """,
+                bygroups(Punctuation, using(this, state=[
+                         'root', 'attr']), Punctuation)
+            ),
+            *text_rules(Generic.Subheading),
+        ],
+        'entity': [
+            (r'&\S*?;', Name.Entity),
+        ],
+        'dt': [
+            (r'\n', Text, '#pop'),
+            include('inline'),
+            (r':', Keyword, '#pop'),
+            include('text'),
+        ],
+        'extlink-inner': [
+            (r'\]', Punctuation, '#pop'),
+            include('inline'),
+            include('text'),
+        ],
+        'nowiki-ish': [
+            include('entity'),
+            include('text'),
+        ],
+        'attr': [
+            include('replaceable'),
+            (r'\s+', Whitespace),
+            (r'(=)(\s*)(")', bygroups(Operator, Whitespace, String.Double), 'attr-val-2'),
+            (r"(=)(\s*)(')", bygroups(Operator, Whitespace, String.Single), 'attr-val-1'),
+            (r'(=)(\s*)', bygroups(Operator, Whitespace), 'attr-val-0'),
+            (r'[\w:-]+', Name.Attribute),
+
+        ],
+        'attr-val-0': [
+            (r'\s', Whitespace, '#pop'),
+            include('replaceable'),
+            *text_rules(String),
+        ],
+        'attr-val-1': [
+            (r"'", String.Single, '#pop'),
+            include('replaceable'),
+            *text_rules(String.Single),
+        ],
+        'attr-val-2': [
+            (r'"', String.Double, '#pop'),
+            include('replaceable'),
+            *text_rules(String.Double),
+        ],
+        'tag-inner-ordinary': [
+            (r'/?\s*>', Punctuation, '#pop'),
+            include('tag-attr'),
+        ],
+        'tag-inner': [
+            # Return to root state for self-closing tags
+            (r'/\s*>', Punctuation, '#pop:2'),
+            (r'\s*>', Punctuation, '#pop'),
+            include('tag-attr'),
+        ],
+        # There states below are just like their non-tag variants, the key difference is
+        # they forcibly quit when encountering tag closing markup
+        'tag-attr': [
+            include('replaceable'),
+            (r'\s+', Whitespace),
+            (r'(=)(\s*)(")', bygroups(Operator,
+             Whitespace, String.Double), 'tag-attr-val-2'),
+            (r"(=)(\s*)(')", bygroups(Operator,
+             Whitespace, String.Single), 'tag-attr-val-1'),
+            (r'(=)(\s*)', bygroups(Operator, Whitespace), 'tag-attr-val-0'),
+            (r'[\w:-]+', Name.Attribute),
+
+        ],
+        'tag-attr-val-0': [
+            (r'\s', Whitespace, '#pop'),
+            (r'/?>', Punctuation, '#pop:2'),
+            include('replaceable'),
+            *text_rules(String),
+        ],
+        'tag-attr-val-1': [
+            (r"'", String.Single, '#pop'),
+            (r'/?>', Punctuation, '#pop:2'),
+            include('replaceable'),
+            *text_rules(String.Single),
+        ],
+        'tag-attr-val-2': [
+            (r'"', String.Double, '#pop'),
+            (r'/?>', Punctuation, '#pop:2'),
+            include('replaceable'),
+            *text_rules(String.Double),
+        ],
+        'tag-nowiki': nowiki_tag_rules('nowiki'),
+        'tag-pre': nowiki_tag_rules('pre'),
+        'tag-categorytree': plaintext_tag_rules('categorytree'),
+        'tag-dynamicpagelist': plaintext_tag_rules('dynamicpagelist'),
+        'tag-hiero': plaintext_tag_rules('hiero'),
+        'tag-inputbox': plaintext_tag_rules('inputbox'),
+        'tag-imagemap': plaintext_tag_rules('imagemap'),
+        'tag-charinsert': plaintext_tag_rules('charinsert'),
+        'tag-timeline': plaintext_tag_rules('timeline'),
+        'tag-gallery': plaintext_tag_rules('gallery'),
+        'tag-graph': plaintext_tag_rules('graph'),
+        'tag-rss': plaintext_tag_rules('rss'),
+        'tag-math': delegate_tag_rules('math', TexLexer, state='math'),
+        'tag-chem': delegate_tag_rules('chem', TexLexer, state='math'),
+        'tag-ce': delegate_tag_rules('ce', TexLexer, state='math'),
+        'tag-templatedata': delegate_tag_rules('templatedata', JsonLexer),
+        'text-italic': text_rules(Generic.Emph),
+        'text-bold': text_rules(Generic.Strong),
+        'text-bold-italic': text_rules(Generic.EmphStrong),
+        'text': text_rules(Text),
+    }
diff --git a/pygments/lexers/math.py b/pygments/lexers/math.py
new file mode 100644
index 0000000000000000000000000000000000000000..b225ffcf935a31fca3a909f84b949fd7acb2cd90
--- /dev/null
+++ b/pygments/lexers/math.py
@@ -0,0 +1,21 @@
+"""
+    pygments.lexers.math
+    ~~~~~~~~~~~~~~~~~~~~
+
+    Just export lexers that were contained in this module.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+# ruff: noqa: F401
+from pygments.lexers.python import NumPyLexer
+from pygments.lexers.matlab import MatlabLexer, MatlabSessionLexer, \
+    OctaveLexer, ScilabLexer
+from pygments.lexers.julia import JuliaLexer, JuliaConsoleLexer
+from pygments.lexers.r import RConsoleLexer, SLexer, RdLexer
+from pygments.lexers.modeling import BugsLexer, JagsLexer, StanLexer
+from pygments.lexers.idl import IDLLexer
+from pygments.lexers.algebra import MuPADLexer
+
+__all__ = []
diff --git a/pygments/lexers/matlab.py b/pygments/lexers/matlab.py
new file mode 100644
index 0000000000000000000000000000000000000000..8eeffc9d23a0ad2b5ec186a8e3da29a02340b95e
--- /dev/null
+++ b/pygments/lexers/matlab.py
@@ -0,0 +1,3307 @@
+"""
+    pygments.lexers.matlab
+    ~~~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for Matlab and related languages.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import re
+
+from pygments.lexer import Lexer, RegexLexer, bygroups, default, words, \
+    do_insertions, include
+from pygments.token import Text, Comment, Operator, Keyword, Name, String, \
+    Number, Punctuation, Generic, Whitespace
+
+from pygments.lexers import _scilab_builtins
+
+__all__ = ['MatlabLexer', 'MatlabSessionLexer', 'OctaveLexer', 'ScilabLexer']
+
+
+class MatlabLexer(RegexLexer):
+    """
+    For Matlab source code.
+    """
+    name = 'Matlab'
+    aliases = ['matlab']
+    filenames = ['*.m']
+    mimetypes = ['text/matlab']
+    url = 'https://www.mathworks.com/products/matlab.html'
+    version_added = '0.10'
+
+    _operators = r'-|==|~=|<=|>=|<|>|&&|&|~|\|\|?|\.\*|\*|\+|\.\^|\^|\.\\|\./|/|\\'
+
+    tokens = {
+        'expressions': [
+            # operators:
+            (_operators, Operator),
+
+            # numbers (must come before punctuation to handle `.5`; cannot use
+            # `\b` due to e.g. `5. + .5`).  The negative lookahead on operators
+            # avoids including the dot in `1./x` (the dot is part of `./`).
+            (rf'(? and then
+            # (equal | open-parenthesis |  | ).
+            (rf'(?:^|(?<=;))(\s*)(\w+)(\s+)(?!=|\(|{_operators}\s|\s)',
+             bygroups(Whitespace, Name, Whitespace), 'commandargs'),
+
+            include('expressions')
+        ],
+        'blockcomment': [
+            (r'^\s*%\}', Comment.Multiline, '#pop'),
+            (r'^.*\n', Comment.Multiline),
+            (r'.', Comment.Multiline),
+        ],
+        'deffunc': [
+            (r'(\s*)(?:(\S+)(\s*)(=)(\s*))?(.+)(\()(.*)(\))(\s*)',
+             bygroups(Whitespace, Text, Whitespace, Punctuation,
+                      Whitespace, Name.Function, Punctuation, Text,
+                      Punctuation, Whitespace), '#pop'),
+            # function with no args
+            (r'(\s*)([a-zA-Z_]\w*)',
+             bygroups(Whitespace, Name.Function), '#pop'),
+        ],
+        'propattrs': [
+            (r'(\w+)(\s*)(=)(\s*)(\d+)',
+             bygroups(Name.Builtin, Whitespace, Punctuation, Whitespace,
+                      Number)),
+            (r'(\w+)(\s*)(=)(\s*)([a-zA-Z]\w*)',
+             bygroups(Name.Builtin, Whitespace, Punctuation, Whitespace,
+                      Keyword)),
+            (r',', Punctuation),
+            (r'\)', Punctuation, '#pop'),
+            (r'\s+', Whitespace),
+            (r'.', Text),
+        ],
+        'defprops': [
+            (r'%\{\s*\n', Comment.Multiline, 'blockcomment'),
+            (r'%.*$', Comment),
+            (r'(?.
+    """
+    name = 'Matlab session'
+    aliases = ['matlabsession']
+    url = 'https://www.mathworks.com/products/matlab.html'
+    version_added = '0.10'
+    _example = "matlabsession/matlabsession_sample.txt"
+
+    def get_tokens_unprocessed(self, text):
+        mlexer = MatlabLexer(**self.options)
+
+        curcode = ''
+        insertions = []
+        continuation = False
+
+        for match in line_re.finditer(text):
+            line = match.group()
+
+            if line.startswith('>> '):
+                insertions.append((len(curcode),
+                                   [(0, Generic.Prompt, line[:3])]))
+                curcode += line[3:]
+
+            elif line.startswith('>>'):
+                insertions.append((len(curcode),
+                                   [(0, Generic.Prompt, line[:2])]))
+                curcode += line[2:]
+
+            elif line.startswith('???'):
+
+                idx = len(curcode)
+
+                # without is showing error on same line as before...?
+                # line = "\n" + line
+                token = (0, Generic.Traceback, line)
+                insertions.append((idx, [token]))
+            elif continuation and insertions:
+                # line_start is the length of the most recent prompt symbol
+                line_start = len(insertions[-1][-1][-1])
+                # Set leading spaces with the length of the prompt to be a generic prompt
+                # This keeps code aligned when prompts are removed, say with some Javascript
+                if line.startswith(' '*line_start):
+                    insertions.append(
+                        (len(curcode), [(0, Generic.Prompt, line[:line_start])]))
+                    curcode += line[line_start:]
+                else:
+                    curcode += line
+            else:
+                if curcode:
+                    yield from do_insertions(
+                        insertions, mlexer.get_tokens_unprocessed(curcode))
+                    curcode = ''
+                    insertions = []
+
+                yield match.start(), Generic.Output, line
+
+            # Does not allow continuation if a comment is included after the ellipses.
+            # Continues any line that ends with ..., even comments (lines that start with %)
+            if line.strip().endswith('...'):
+                continuation = True
+            else:
+                continuation = False
+
+        if curcode:  # or item:
+            yield from do_insertions(
+                insertions, mlexer.get_tokens_unprocessed(curcode))
+
+
+class OctaveLexer(RegexLexer):
+    """
+    For GNU Octave source code.
+    """
+    name = 'Octave'
+    url = 'https://www.gnu.org/software/octave/index'
+    aliases = ['octave']
+    filenames = ['*.m']
+    mimetypes = ['text/octave']
+    version_added = '1.5'
+
+    # These lists are generated automatically.
+    # Run the following in bash shell:
+    #
+    # First dump all of the Octave manual into a plain text file:
+    #
+    #   $ info octave --subnodes -o octave-manual
+    #
+    # Now grep through it:
+
+    # for i in \
+    #     "Built-in Function" "Command" "Function File" \
+    #     "Loadable Function" "Mapping Function";
+    # do
+    #     perl -e '@name = qw('"$i"');
+    #              print lc($name[0]),"_kw = [\n"';
+    #
+    #     perl -n -e 'print "\"$1\",\n" if /-- '"$i"': .* (\w*) \(/;' \
+    #         octave-manual | sort | uniq ;
+    #     echo "]" ;
+    #     echo;
+    # done
+
+    # taken from Octave Mercurial changeset 8cc154f45e37 (30-jan-2011)
+
+    builtin_kw = (
+        "addlistener", "addpath", "addproperty", "all",
+        "and", "any", "argnames", "argv", "assignin",
+        "atexit", "autoload",
+        "available_graphics_toolkits", "beep_on_error",
+        "bitand", "bitmax", "bitor", "bitshift", "bitxor",
+        "cat", "cell", "cellstr", "char", "class", "clc",
+        "columns", "command_line_path",
+        "completion_append_char", "completion_matches",
+        "complex", "confirm_recursive_rmdir", "cputime",
+        "crash_dumps_octave_core", "ctranspose", "cumprod",
+        "cumsum", "debug_on_error", "debug_on_interrupt",
+        "debug_on_warning", "default_save_options",
+        "dellistener", "diag", "diff", "disp",
+        "doc_cache_file", "do_string_escapes", "double",
+        "drawnow", "e", "echo_executing_commands", "eps",
+        "eq", "errno", "errno_list", "error", "eval",
+        "evalin", "exec", "exist", "exit", "eye", "false",
+        "fclear", "fclose", "fcntl", "fdisp", "feof",
+        "ferror", "feval", "fflush", "fgetl", "fgets",
+        "fieldnames", "file_in_loadpath", "file_in_path",
+        "filemarker", "filesep", "find_dir_in_path",
+        "fixed_point_format", "fnmatch", "fopen", "fork",
+        "formula", "fprintf", "fputs", "fread", "freport",
+        "frewind", "fscanf", "fseek", "fskipl", "ftell",
+        "functions", "fwrite", "ge", "genpath", "get",
+        "getegid", "getenv", "geteuid", "getgid",
+        "getpgrp", "getpid", "getppid", "getuid", "glob",
+        "gt", "gui_mode", "history_control",
+        "history_file", "history_size",
+        "history_timestamp_format_string", "home",
+        "horzcat", "hypot", "ifelse",
+        "ignore_function_time_stamp", "inferiorto",
+        "info_file", "info_program", "inline", "input",
+        "intmax", "intmin", "ipermute",
+        "is_absolute_filename", "isargout", "isbool",
+        "iscell", "iscellstr", "ischar", "iscomplex",
+        "isempty", "isfield", "isfloat", "isglobal",
+        "ishandle", "isieee", "isindex", "isinteger",
+        "islogical", "ismatrix", "ismethod", "isnull",
+        "isnumeric", "isobject", "isreal",
+        "is_rooted_relative_filename", "issorted",
+        "isstruct", "isvarname", "kbhit", "keyboard",
+        "kill", "lasterr", "lasterror", "lastwarn",
+        "ldivide", "le", "length", "link", "linspace",
+        "logical", "lstat", "lt", "make_absolute_filename",
+        "makeinfo_program", "max_recursion_depth", "merge",
+        "methods", "mfilename", "minus", "mislocked",
+        "mkdir", "mkfifo", "mkstemp", "mldivide", "mlock",
+        "mouse_wheel_zoom", "mpower", "mrdivide", "mtimes",
+        "munlock", "nargin", "nargout",
+        "native_float_format", "ndims", "ne", "nfields",
+        "nnz", "norm", "not", "numel", "nzmax",
+        "octave_config_info", "octave_core_file_limit",
+        "octave_core_file_name",
+        "octave_core_file_options", "ones", "or",
+        "output_max_field_width", "output_precision",
+        "page_output_immediately", "page_screen_output",
+        "path", "pathsep", "pause", "pclose", "permute",
+        "pi", "pipe", "plus", "popen", "power",
+        "print_empty_dimensions", "printf",
+        "print_struct_array_contents", "prod",
+        "program_invocation_name", "program_name",
+        "putenv", "puts", "pwd", "quit", "rats", "rdivide",
+        "readdir", "readlink", "read_readline_init_file",
+        "realmax", "realmin", "rehash", "rename",
+        "repelems", "re_read_readline_init_file", "reset",
+        "reshape", "resize", "restoredefaultpath",
+        "rethrow", "rmdir", "rmfield", "rmpath", "rows",
+        "save_header_format_string", "save_precision",
+        "saving_history", "scanf", "set", "setenv",
+        "shell_cmd", "sighup_dumps_octave_core",
+        "sigterm_dumps_octave_core", "silent_functions",
+        "single", "size", "size_equal", "sizemax",
+        "sizeof", "sleep", "source", "sparse_auto_mutate",
+        "split_long_rows", "sprintf", "squeeze", "sscanf",
+        "stat", "stderr", "stdin", "stdout", "strcmp",
+        "strcmpi", "string_fill_char", "strncmp",
+        "strncmpi", "struct", "struct_levels_to_print",
+        "strvcat", "subsasgn", "subsref", "sum", "sumsq",
+        "superiorto", "suppress_verbose_help_message",
+        "symlink", "system", "tic", "tilde_expand",
+        "times", "tmpfile", "tmpnam", "toc", "toupper",
+        "transpose", "true", "typeinfo", "umask", "uminus",
+        "uname", "undo_string_escapes", "unlink", "uplus",
+        "upper", "usage", "usleep", "vec", "vectorize",
+        "vertcat", "waitpid", "warning", "warranty",
+        "whos_line_format", "yes_or_no", "zeros",
+        "inf", "Inf", "nan", "NaN")
+
+    command_kw = ("close", "load", "who", "whos")
+
+    function_kw = (
+        "accumarray", "accumdim", "acosd", "acotd",
+        "acscd", "addtodate", "allchild", "ancestor",
+        "anova", "arch_fit", "arch_rnd", "arch_test",
+        "area", "arma_rnd", "arrayfun", "ascii", "asctime",
+        "asecd", "asind", "assert", "atand",
+        "autoreg_matrix", "autumn", "axes", "axis", "bar",
+        "barh", "bartlett", "bartlett_test", "beep",
+        "betacdf", "betainv", "betapdf", "betarnd",
+        "bicgstab", "bicubic", "binary", "binocdf",
+        "binoinv", "binopdf", "binornd", "bitcmp",
+        "bitget", "bitset", "blackman", "blanks",
+        "blkdiag", "bone", "box", "brighten", "calendar",
+        "cast", "cauchy_cdf", "cauchy_inv", "cauchy_pdf",
+        "cauchy_rnd", "caxis", "celldisp", "center", "cgs",
+        "chisquare_test_homogeneity",
+        "chisquare_test_independence", "circshift", "cla",
+        "clabel", "clf", "clock", "cloglog", "closereq",
+        "colon", "colorbar", "colormap", "colperm",
+        "comet", "common_size", "commutation_matrix",
+        "compan", "compare_versions", "compass",
+        "computer", "cond", "condest", "contour",
+        "contourc", "contourf", "contrast", "conv",
+        "convhull", "cool", "copper", "copyfile", "cor",
+        "corrcoef", "cor_test", "cosd", "cotd", "cov",
+        "cplxpair", "cross", "cscd", "cstrcat", "csvread",
+        "csvwrite", "ctime", "cumtrapz", "curl", "cut",
+        "cylinder", "date", "datenum", "datestr",
+        "datetick", "datevec", "dblquad", "deal",
+        "deblank", "deconv", "delaunay", "delaunayn",
+        "delete", "demo", "detrend", "diffpara", "diffuse",
+        "dir", "discrete_cdf", "discrete_inv",
+        "discrete_pdf", "discrete_rnd", "display",
+        "divergence", "dlmwrite", "dos", "dsearch",
+        "dsearchn", "duplication_matrix", "durbinlevinson",
+        "ellipsoid", "empirical_cdf", "empirical_inv",
+        "empirical_pdf", "empirical_rnd", "eomday",
+        "errorbar", "etime", "etreeplot", "example",
+        "expcdf", "expinv", "expm", "exppdf", "exprnd",
+        "ezcontour", "ezcontourf", "ezmesh", "ezmeshc",
+        "ezplot", "ezpolar", "ezsurf", "ezsurfc", "factor",
+        "factorial", "fail", "fcdf", "feather", "fftconv",
+        "fftfilt", "fftshift", "figure", "fileattrib",
+        "fileparts", "fill", "findall", "findobj",
+        "findstr", "finv", "flag", "flipdim", "fliplr",
+        "flipud", "fpdf", "fplot", "fractdiff", "freqz",
+        "freqz_plot", "frnd", "fsolve",
+        "f_test_regression", "ftp", "fullfile", "fzero",
+        "gamcdf", "gaminv", "gampdf", "gamrnd", "gca",
+        "gcbf", "gcbo", "gcf", "genvarname", "geocdf",
+        "geoinv", "geopdf", "geornd", "getfield", "ginput",
+        "glpk", "gls", "gplot", "gradient",
+        "graphics_toolkit", "gray", "grid", "griddata",
+        "griddatan", "gtext", "gunzip", "gzip", "hadamard",
+        "hamming", "hankel", "hanning", "hggroup",
+        "hidden", "hilb", "hist", "histc", "hold", "hot",
+        "hotelling_test", "housh", "hsv", "hurst",
+        "hygecdf", "hygeinv", "hygepdf", "hygernd",
+        "idivide", "ifftshift", "image", "imagesc",
+        "imfinfo", "imread", "imshow", "imwrite", "index",
+        "info", "inpolygon", "inputname", "interpft",
+        "interpn", "intersect", "invhilb", "iqr", "isa",
+        "isdefinite", "isdir", "is_duplicate_entry",
+        "isequal", "isequalwithequalnans", "isfigure",
+        "ishermitian", "ishghandle", "is_leap_year",
+        "isletter", "ismac", "ismember", "ispc", "isprime",
+        "isprop", "isscalar", "issquare", "isstrprop",
+        "issymmetric", "isunix", "is_valid_file_id",
+        "isvector", "jet", "kendall",
+        "kolmogorov_smirnov_cdf",
+        "kolmogorov_smirnov_test", "kruskal_wallis_test",
+        "krylov", "kurtosis", "laplace_cdf", "laplace_inv",
+        "laplace_pdf", "laplace_rnd", "legend", "legendre",
+        "license", "line", "linkprop", "list_primes",
+        "loadaudio", "loadobj", "logistic_cdf",
+        "logistic_inv", "logistic_pdf", "logistic_rnd",
+        "logit", "loglog", "loglogerr", "logm", "logncdf",
+        "logninv", "lognpdf", "lognrnd", "logspace",
+        "lookfor", "ls_command", "lsqnonneg", "magic",
+        "mahalanobis", "manova", "matlabroot",
+        "mcnemar_test", "mean", "meansq", "median", "menu",
+        "mesh", "meshc", "meshgrid", "meshz", "mexext",
+        "mget", "mkpp", "mode", "moment", "movefile",
+        "mpoles", "mput", "namelengthmax", "nargchk",
+        "nargoutchk", "nbincdf", "nbininv", "nbinpdf",
+        "nbinrnd", "nchoosek", "ndgrid", "newplot", "news",
+        "nonzeros", "normcdf", "normest", "norminv",
+        "normpdf", "normrnd", "now", "nthroot", "null",
+        "ocean", "ols", "onenormest", "optimget",
+        "optimset", "orderfields", "orient", "orth",
+        "pack", "pareto", "parseparams", "pascal", "patch",
+        "pathdef", "pcg", "pchip", "pcolor", "pcr",
+        "peaks", "periodogram", "perl", "perms", "pie",
+        "pink", "planerot", "playaudio", "plot",
+        "plotmatrix", "plotyy", "poisscdf", "poissinv",
+        "poisspdf", "poissrnd", "polar", "poly",
+        "polyaffine", "polyarea", "polyderiv", "polyfit",
+        "polygcd", "polyint", "polyout", "polyreduce",
+        "polyval", "polyvalm", "postpad", "powerset",
+        "ppder", "ppint", "ppjumps", "ppplot", "ppval",
+        "pqpnonneg", "prepad", "primes", "print",
+        "print_usage", "prism", "probit", "qp", "qqplot",
+        "quadcc", "quadgk", "quadl", "quadv", "quiver",
+        "qzhess", "rainbow", "randi", "range", "rank",
+        "ranks", "rat", "reallog", "realpow", "realsqrt",
+        "record", "rectangle_lw", "rectangle_sw",
+        "rectint", "refresh", "refreshdata",
+        "regexptranslate", "repmat", "residue", "ribbon",
+        "rindex", "roots", "rose", "rosser", "rotdim",
+        "rref", "run", "run_count", "rundemos", "run_test",
+        "runtests", "saveas", "saveaudio", "saveobj",
+        "savepath", "scatter", "secd", "semilogx",
+        "semilogxerr", "semilogy", "semilogyerr",
+        "setaudio", "setdiff", "setfield", "setxor",
+        "shading", "shift", "shiftdim", "sign_test",
+        "sinc", "sind", "sinetone", "sinewave", "skewness",
+        "slice", "sombrero", "sortrows", "spaugment",
+        "spconvert", "spdiags", "spearman", "spectral_adf",
+        "spectral_xdf", "specular", "speed", "spencer",
+        "speye", "spfun", "sphere", "spinmap", "spline",
+        "spones", "sprand", "sprandn", "sprandsym",
+        "spring", "spstats", "spy", "sqp", "stairs",
+        "statistics", "std", "stdnormal_cdf",
+        "stdnormal_inv", "stdnormal_pdf", "stdnormal_rnd",
+        "stem", "stft", "strcat", "strchr", "strjust",
+        "strmatch", "strread", "strsplit", "strtok",
+        "strtrim", "strtrunc", "structfun", "studentize",
+        "subplot", "subsindex", "subspace", "substr",
+        "substruct", "summer", "surf", "surface", "surfc",
+        "surfl", "surfnorm", "svds", "swapbytes",
+        "sylvester_matrix", "symvar", "synthesis", "table",
+        "tand", "tar", "tcdf", "tempdir", "tempname",
+        "test", "text", "textread", "textscan", "tinv",
+        "title", "toeplitz", "tpdf", "trace", "trapz",
+        "treelayout", "treeplot", "triangle_lw",
+        "triangle_sw", "tril", "trimesh", "triplequad",
+        "triplot", "trisurf", "triu", "trnd", "tsearchn",
+        "t_test", "t_test_regression", "type", "unidcdf",
+        "unidinv", "unidpdf", "unidrnd", "unifcdf",
+        "unifinv", "unifpdf", "unifrnd", "union", "unique",
+        "unix", "unmkpp", "unpack", "untabify", "untar",
+        "unwrap", "unzip", "u_test", "validatestring",
+        "vander", "var", "var_test", "vech", "ver",
+        "version", "view", "voronoi", "voronoin",
+        "waitforbuttonpress", "wavread", "wavwrite",
+        "wblcdf", "wblinv", "wblpdf", "wblrnd", "weekday",
+        "welch_test", "what", "white", "whitebg",
+        "wienrnd", "wilcoxon_test", "wilkinson", "winter",
+        "xlabel", "xlim", "ylabel", "yulewalker", "zip",
+        "zlabel", "z_test")
+
+    loadable_kw = (
+        "airy", "amd", "balance", "besselh", "besseli",
+        "besselj", "besselk", "bessely", "bitpack",
+        "bsxfun", "builtin", "ccolamd", "cellfun",
+        "cellslices", "chol", "choldelete", "cholinsert",
+        "cholinv", "cholshift", "cholupdate", "colamd",
+        "colloc", "convhulln", "convn", "csymamd",
+        "cummax", "cummin", "daspk", "daspk_options",
+        "dasrt", "dasrt_options", "dassl", "dassl_options",
+        "dbclear", "dbdown", "dbstack", "dbstatus",
+        "dbstop", "dbtype", "dbup", "dbwhere", "det",
+        "dlmread", "dmperm", "dot", "eig", "eigs",
+        "endgrent", "endpwent", "etree", "fft", "fftn",
+        "fftw", "filter", "find", "full", "gcd",
+        "getgrent", "getgrgid", "getgrnam", "getpwent",
+        "getpwnam", "getpwuid", "getrusage", "givens",
+        "gmtime", "gnuplot_binary", "hess", "ifft",
+        "ifftn", "inv", "isdebugmode", "issparse", "kron",
+        "localtime", "lookup", "lsode", "lsode_options",
+        "lu", "luinc", "luupdate", "matrix_type", "max",
+        "min", "mktime", "pinv", "qr", "qrdelete",
+        "qrinsert", "qrshift", "qrupdate", "quad",
+        "quad_options", "qz", "rand", "rande", "randg",
+        "randn", "randp", "randperm", "rcond", "regexp",
+        "regexpi", "regexprep", "schur", "setgrent",
+        "setpwent", "sort", "spalloc", "sparse", "spparms",
+        "sprank", "sqrtm", "strfind", "strftime",
+        "strptime", "strrep", "svd", "svd_driver", "syl",
+        "symamd", "symbfact", "symrcm", "time", "tsearch",
+        "typecast", "urlread", "urlwrite")
+
+    mapping_kw = (
+        "abs", "acos", "acosh", "acot", "acoth", "acsc",
+        "acsch", "angle", "arg", "asec", "asech", "asin",
+        "asinh", "atan", "atanh", "beta", "betainc",
+        "betaln", "bincoeff", "cbrt", "ceil", "conj", "cos",
+        "cosh", "cot", "coth", "csc", "csch", "erf", "erfc",
+        "erfcx", "erfinv", "exp", "finite", "fix", "floor",
+        "fmod", "gamma", "gammainc", "gammaln", "imag",
+        "isalnum", "isalpha", "isascii", "iscntrl",
+        "isdigit", "isfinite", "isgraph", "isinf",
+        "islower", "isna", "isnan", "isprint", "ispunct",
+        "isspace", "isupper", "isxdigit", "lcm", "lgamma",
+        "log", "lower", "mod", "real", "rem", "round",
+        "roundb", "sec", "sech", "sign", "sin", "sinh",
+        "sqrt", "tan", "tanh", "toascii", "tolower", "xor")
+
+    builtin_consts = (
+        "EDITOR", "EXEC_PATH", "I", "IMAGE_PATH", "NA",
+        "OCTAVE_HOME", "OCTAVE_VERSION", "PAGER",
+        "PAGER_FLAGS", "SEEK_CUR", "SEEK_END", "SEEK_SET",
+        "SIG", "S_ISBLK", "S_ISCHR", "S_ISDIR", "S_ISFIFO",
+        "S_ISLNK", "S_ISREG", "S_ISSOCK", "WCONTINUE",
+        "WCOREDUMP", "WEXITSTATUS", "WIFCONTINUED",
+        "WIFEXITED", "WIFSIGNALED", "WIFSTOPPED", "WNOHANG",
+        "WSTOPSIG", "WTERMSIG", "WUNTRACED")
+
+    tokens = {
+        'root': [
+            (r'%\{\s*\n', Comment.Multiline, 'percentblockcomment'),
+            (r'#\{\s*\n', Comment.Multiline, 'hashblockcomment'),
+            (r'[%#].*$', Comment),
+            (r'^\s*function\b', Keyword, 'deffunc'),
+
+            # from 'iskeyword' on hg changeset 8cc154f45e37
+            (words((
+                '__FILE__', '__LINE__', 'break', 'case', 'catch', 'classdef',
+                'continue', 'do', 'else', 'elseif', 'end', 'end_try_catch',
+                'end_unwind_protect', 'endclassdef', 'endevents', 'endfor',
+                'endfunction', 'endif', 'endmethods', 'endproperties', 'endswitch',
+                'endwhile', 'events', 'for', 'function', 'get', 'global', 'if',
+                'methods', 'otherwise', 'persistent', 'properties', 'return',
+                'set', 'static', 'switch', 'try', 'until', 'unwind_protect',
+                'unwind_protect_cleanup', 'while'), suffix=r'\b'),
+             Keyword),
+
+            (words(builtin_kw + command_kw + function_kw + loadable_kw + mapping_kw,
+                   suffix=r'\b'),  Name.Builtin),
+
+            (words(builtin_consts, suffix=r'\b'), Name.Constant),
+
+            # operators in Octave but not Matlab:
+            (r'-=|!=|!|/=|--', Operator),
+            # operators:
+            (r'-|==|~=|<|>|<=|>=|&&|&|~|\|\|?', Operator),
+            # operators in Octave but not Matlab requiring escape for re:
+            (r'\*=|\+=|\^=|\/=|\\=|\*\*|\+\+|\.\*\*', Operator),
+            # operators requiring escape for re:
+            (r'\.\*|\*|\+|\.\^|\^|\.\\|\.\/|\/|\\', Operator),
+
+
+            # punctuation:
+            (r'[\[\](){}:@.,]', Punctuation),
+            (r'=|:|;', Punctuation),
+
+            (r'"[^"]*"', String),
+
+            (r'(\d+\.\d*|\d*\.\d+)([eEf][+-]?[0-9]+)?', Number.Float),
+            (r'\d+[eEf][+-]?[0-9]+', Number.Float),
+            (r'\d+', Number.Integer),
+
+            # quote can be transpose, instead of string:
+            # (not great, but handles common cases...)
+            (r'(?<=[\w)\].])\'+', Operator),
+            (r'(?|<=|>=|&&|&|~|\|\|?', Operator),
+            # operators requiring escape for re:
+            (r'\.\*|\*|\+|\.\^|\^|\.\\|\.\/|\/|\\', Operator),
+
+            # punctuation:
+            (r'[\[\](){}@.,=:;]+', Punctuation),
+
+            (r'"[^"]*"', String),
+
+            # quote can be transpose, instead of string:
+            # (not great, but handles common cases...)
+            (r'(?<=[\w)\].])\'+', Operator),
+            (r'(?', r'<', r'|', r'!', r"'")
+
+    operator_words = ('and', 'or', 'not')
+
+    tokens = {
+        'root': [
+            (r'/\*', Comment.Multiline, 'comment'),
+            (r'"(?:[^"\\]|\\.)*"', String),
+            (r'\(|\)|\[|\]|\{|\}', Punctuation),
+            (r'[,;$]', Punctuation),
+            (words (constants), Name.Constant),
+            (words (keywords), Keyword),
+            (words (operators), Operator),
+            (words (operator_words), Operator.Word),
+            (r'''(?x)
+              ((?:[a-zA-Z_#][\w#]*|`[^`]*`)
+              (?:::[a-zA-Z_#][\w#]*|`[^`]*`)*)(\s*)([(])''',
+             bygroups(Name.Function, Text.Whitespace, Punctuation)),
+            (r'''(?x)
+              (?:[a-zA-Z_#%][\w#%]*|`[^`]*`)
+              (?:::[a-zA-Z_#%][\w#%]*|`[^`]*`)*''', Name.Variable),
+            (r'[-+]?(\d*\.\d+([bdefls][-+]?\d+)?|\d+(\.\d*)?[bdefls][-+]?\d+)', Number.Float),
+            (r'[-+]?\d+', Number.Integer),
+            (r'\s+', Text.Whitespace),
+            (r'.', Text)
+        ],
+        'comment': [
+            (r'[^*/]+', Comment.Multiline),
+            (r'/\*', Comment.Multiline, '#push'),
+            (r'\*/', Comment.Multiline, '#pop'),
+            (r'[*/]', Comment.Multiline)
+        ]
+    }
+
+    def analyse_text (text):
+        strength = 0.0
+        # Input expression terminator.
+        if re.search (r'\$\s*$', text, re.MULTILINE):
+            strength += 0.05
+        # Function definition operator.
+        if ':=' in text:
+            strength += 0.02
+        return strength
diff --git a/pygments/lexers/meson.py b/pygments/lexers/meson.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f2c6da374b0e83fc3adc2c577c77884b7d09923
--- /dev/null
+++ b/pygments/lexers/meson.py
@@ -0,0 +1,139 @@
+"""
+    pygments.lexers.meson
+    ~~~~~~~~~~~~~~~~~~~~~
+
+    Pygments lexer for the Meson build system
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, words, include
+from pygments.token import Comment, Name, Number, Punctuation, Operator, \
+    Keyword, String, Whitespace
+
+__all__ = ['MesonLexer']
+
+
+class MesonLexer(RegexLexer):
+    """Meson language lexer.
+
+    The grammar definition use to transcribe the syntax was retrieved from
+    https://mesonbuild.com/Syntax.html#grammar for version 0.58.
+    Some of those definitions are improperly transcribed, so the Meson++
+    implementation was also checked: https://github.com/dcbaker/meson-plus-plus.
+    """
+
+    # TODO String interpolation @VARNAME@ inner matches
+    # TODO keyword_arg: value inner matches
+
+    name = 'Meson'
+    url = 'https://mesonbuild.com/'
+    aliases = ['meson', 'meson.build']
+    filenames = ['meson.build', 'meson_options.txt']
+    mimetypes = ['text/x-meson']
+    version_added = '2.10'
+
+    tokens = {
+        'root': [
+            (r'#.*?$', Comment),
+            (r"'''.*'''", String.Single),
+            (r'[1-9][0-9]*', Number.Integer),
+            (r'0o[0-7]+', Number.Oct),
+            (r'0x[a-fA-F0-9]+', Number.Hex),
+            include('string'),
+            include('keywords'),
+            include('expr'),
+            (r'[a-zA-Z_][a-zA-Z_0-9]*', Name),
+            (r'\s+', Whitespace),
+        ],
+        'string': [
+            (r"[']{3}([']{0,2}([^\\']|\\(.|\n)))*[']{3}", String),
+            (r"'.*?(?`_.
+    """
+
+    name = "MCFunction"
+    url = "https://minecraft.wiki/w/Commands"
+    aliases = ["mcfunction", "mcf"]
+    filenames = ["*.mcfunction"]
+    mimetypes = ["text/mcfunction"]
+    version_added = '2.12'
+
+    # Used to denotate the start of a block comment, borrowed from Github's mcfunction
+    _block_comment_prefix = "[>!]"
+
+    tokens = {
+        "root": [
+            include("names"),
+            include("comments"),
+            include("literals"),
+            include("whitespace"),
+            include("property"),
+            include("operators"),
+            include("selectors"),
+        ],
+
+        "names": [
+            # The start of a command (either beginning of line OR after the run keyword)
+            #  We don't encode a list of keywords since mods, plugins, or even pre-processors
+            #  may add new commands, so we have a 'close-enough' regex which catches them.
+            (r"^(\s*)([a-z_]+)", bygroups(Whitespace, Name.Builtin)),
+            (r"(?<=run)\s+[a-z_]+", Name.Builtin),
+
+            # UUID
+            (r"\b[0-9a-fA-F]+(?:-[0-9a-fA-F]+){4}\b", Name.Variable),
+            include("resource-name"),
+            # normal command names and scoreboards
+            #  there's no way to know the differences unfortuntely
+            (r"[A-Za-z_][\w.#%$]+", Keyword.Constant),
+            (r"[#%$][\w.#%$]+", Name.Variable.Magic),
+        ],
+
+        "resource-name": [
+            # resource names have to be lowercase
+            (r"#?[a-z_][a-z_.-]*:[a-z0-9_./-]+", Name.Function),
+            # similar to above except optional `:``
+            #  a `/` must be present "somewhere"
+            (r"#?[a-z0-9_\.\-]+\/[a-z0-9_\.\-\/]+", Name.Function),
+        ],
+
+        "whitespace": [
+            (r"\s+", Whitespace),
+        ],
+
+        "comments": [
+            (rf"^\s*(#{_block_comment_prefix})", Comment.Multiline,
+             ("comments.block", "comments.block.emphasized")),
+            (r"#.*$", Comment.Single),
+        ],
+        "comments.block": [
+            (rf"^\s*#{_block_comment_prefix}", Comment.Multiline,
+             "comments.block.emphasized"),
+            (r"^\s*#", Comment.Multiline, "comments.block.normal"),
+            default("#pop"),
+        ],
+        "comments.block.normal": [
+            include("comments.block.special"),
+            (r"\S+", Comment.Multiline),
+            (r"\n", Text, "#pop"),
+            include("whitespace"),
+        ],
+        "comments.block.emphasized": [
+            include("comments.block.special"),
+            (r"\S+", String.Doc),
+            (r"\n", Text, "#pop"),
+            include("whitespace"),
+        ],
+        "comments.block.special": [
+            # Params
+            (r"@\S+", Name.Decorator),
+
+            include("resource-name"),
+
+            # Scoreboard player names
+            (r"[#%$][\w.#%$]+", Name.Variable.Magic),
+        ],
+
+        "operators": [
+            (r"[\-~%^?!+*<>\\/|&=.]", Operator),
+        ],
+
+        "literals": [
+            (r"\.\.", Literal),
+            (r"(true|false)", Keyword.Pseudo),
+
+            # these are like unquoted strings and appear in many places
+            (r"[A-Za-z_]+", Name.Variable.Class),
+
+            (r"[0-7]b", Number.Byte),
+            (r"[+-]?\d*\.?\d+([eE]?[+-]?\d+)?[df]?\b", Number.Float),
+            (r"[+-]?\d+\b", Number.Integer),
+            (r'"', String.Double, "literals.string-double"),
+            (r"'", String.Single, "literals.string-single"),
+        ],
+        "literals.string-double": [
+            (r"\\.", String.Escape),
+            (r'[^\\"\n]+', String.Double),
+            (r'"', String.Double, "#pop"),
+        ],
+        "literals.string-single": [
+            (r"\\.", String.Escape),
+            (r"[^\\'\n]+", String.Single),
+            (r"'", String.Single, "#pop"),
+        ],
+
+        "selectors": [
+            (r"@[a-z]", Name.Variable),
+        ],
+
+
+        ## Generic Property Container
+        # There are several, differing instances where the language accepts
+        #  specific contained keys or contained key, value pairings.
+        #
+        # Property Maps:
+        # - Starts with either `[` or `{`
+        # - Key separated by `:` or `=`
+        # - Deliminated by `,`
+        #
+        # Property Lists:
+        # - Starts with `[`
+        # - Deliminated by `,`
+        #
+        # For simplicity, these patterns match a generic, nestable structure
+        #  which follow a key, value pattern. For normal lists, there's only keys.
+        # This allow some "illegal" structures, but we'll accept those for
+        #  sake of simplicity
+        #
+        # Examples:
+        # - `[facing=up, powered=true]` (blockstate)
+        # - `[name="hello world", nbt={key: 1b}]` (selector + nbt)
+        # - `[{"text": "value"}, "literal"]` (json)
+        ##
+        "property": [
+            # This state gets included in root and also several substates
+            # We do this to shortcut the starting of new properties
+            #  within other properties. Lists can have sublists and compounds
+            #  and values can start a new property (see the `difficult_1.txt`
+            #  snippet).
+            (r"\{", Punctuation, ("property.curly", "property.key")),
+            (r"\[", Punctuation, ("property.square", "property.key")),
+        ],
+        "property.curly": [
+            include("whitespace"),
+            include("property"),
+            (r"\}", Punctuation, "#pop"),
+        ],
+        "property.square": [
+            include("whitespace"),
+            include("property"),
+            (r"\]", Punctuation, "#pop"),
+
+            # lists can have sequences of items
+            (r",", Punctuation),
+        ],
+        "property.key": [
+            include("whitespace"),
+
+            # resource names (for advancements)
+            #  can omit `:` to default `minecraft:`
+            # must check if there is a future equals sign if `:` is in the name
+            (r"#?[a-z_][a-z_\.\-]*\:[a-z0-9_\.\-/]+(?=\s*\=)", Name.Attribute, "property.delimiter"),
+            (r"#?[a-z_][a-z0-9_\.\-/]+", Name.Attribute, "property.delimiter"),
+
+            # unquoted NBT key
+            (r"[A-Za-z_\-\+]+", Name.Attribute, "property.delimiter"),
+
+            # quoted JSON or NBT key
+            (r'"', Name.Attribute, "property.delimiter", "literals.string-double"),
+            (r"'", Name.Attribute, "property.delimiter", "literals.string-single"),
+
+            # index for a list
+            (r"-?\d+", Number.Integer, "property.delimiter"),
+
+            default("#pop"),
+        ],
+        "property.key.string-double": [
+            (r"\\.", String.Escape),
+            (r'[^\\"\n]+', Name.Attribute),
+            (r'"', Name.Attribute, "#pop"),
+        ],
+        "property.key.string-single": [
+            (r"\\.", String.Escape),
+            (r"[^\\'\n]+", Name.Attribute),
+            (r"'", Name.Attribute, "#pop"),
+        ],
+        "property.delimiter": [
+            include("whitespace"),
+
+            (r"[:=]!?", Punctuation, "property.value"),
+            (r",", Punctuation),
+
+            default("#pop"),
+        ],
+        "property.value": [
+            include("whitespace"),
+
+            # unquoted resource names are valid literals here
+            (r"#?[a-z_][a-z_\.\-]*\:[a-z0-9_\.\-/]+", Name.Tag),
+            (r"#?[a-z_][a-z0-9_\.\-/]+", Name.Tag),
+
+            include("literals"),
+            include("property"),
+
+            default("#pop"),
+        ],
+    }
+
+
+class MCSchemaLexer(RegexLexer):
+    """Lexer for Minecraft Add-ons data Schemas, an interface structure standard used in Minecraft
+    """
+
+    name = 'MCSchema'
+    url = 'https://learn.microsoft.com/en-us/minecraft/creator/reference/content/schemasreference/'
+    aliases = ['mcschema']
+    filenames = ['*.mcschema']
+    mimetypes = ['text/mcschema']
+    version_added = '2.14'
+
+    tokens = {
+        'commentsandwhitespace': [
+            (r'\s+', Whitespace),
+            (r'//.*?$', Comment.Single),
+            (r'/\*.*?\*/', Comment.Multiline)
+        ],
+        'slashstartsregex': [
+            include('commentsandwhitespace'),
+            (r'/(\\.|[^[/\\\n]|\[(\\.|[^\]\\\n])*])+/'
+             r'([gimuysd]+\b|\B)', String.Regex, '#pop'),
+            (r'(?=/)', Text, ('#pop', 'badregex')),
+            default('#pop')
+        ],
+        'badregex': [
+            (r'\n', Whitespace, '#pop')
+        ],
+        'singlestring': [
+            (r'\\.', String.Escape),
+            (r"'", String.Single, '#pop'),
+            (r"[^\\']+", String.Single),
+        ],
+        'doublestring': [
+            (r'\\.', String.Escape),
+            (r'"', String.Double, '#pop'),
+            (r'[^\\"]+', String.Double),
+        ],
+        'root': [
+            (r'^(?=\s|/|', Comment, '#pop'),
+            (r'[^\-]+|-', Comment),
+        ],
+    }
+
+
+class ReasonLexer(RegexLexer):
+    """
+    For the ReasonML language.
+    """
+
+    name = 'ReasonML'
+    url = 'https://reasonml.github.io/'
+    aliases = ['reasonml', 'reason']
+    filenames = ['*.re', '*.rei']
+    mimetypes = ['text/x-reasonml']
+    version_added = '2.6'
+
+    keywords = (
+        'as', 'assert', 'begin', 'class', 'constraint', 'do', 'done', 'downto',
+        'else', 'end', 'exception', 'external', 'false', 'for', 'fun', 'esfun',
+        'function', 'functor', 'if', 'in', 'include', 'inherit', 'initializer', 'lazy',
+        'let', 'switch', 'module', 'pub', 'mutable', 'new', 'nonrec', 'object', 'of',
+        'open', 'pri', 'rec', 'sig', 'struct', 'then', 'to', 'true', 'try',
+        'type', 'val', 'virtual', 'when', 'while', 'with',
+    )
+    keyopts = (
+        '!=', '#', '&', '&&', r'\(', r'\)', r'\*', r'\+', ',', '-',
+        r'-\.', '=>', r'\.', r'\.\.', r'\.\.\.', ':', '::', ':=', ':>', ';', ';;', '<',
+        '<-', '=', '>', '>]', r'>\}', r'\?', r'\?\?', r'\[', r'\[<', r'\[>',
+        r'\[\|', ']', '_', '`', r'\{', r'\{<', r'\|', r'\|\|', r'\|]', r'\}', '~'
+    )
+
+    operators = r'[!$%&*+\./:<=>?@^|~-]'
+    word_operators = ('and', 'asr', 'land', 'lor', 'lsl', 'lsr', 'lxor', 'mod', 'or')
+    prefix_syms = r'[!?~]'
+    infix_syms = r'[=<>@^|&+\*/$%-]'
+    primitives = ('unit', 'int', 'float', 'bool', 'string', 'char', 'list', 'array')
+
+    tokens = {
+        'escape-sequence': [
+            (r'\\[\\"\'ntbr]', String.Escape),
+            (r'\\[0-9]{3}', String.Escape),
+            (r'\\x[0-9a-fA-F]{2}', String.Escape),
+        ],
+        'root': [
+            (r'\s+', Text),
+            (r'false|true|\(\)|\[\]', Name.Builtin.Pseudo),
+            (r'\b([A-Z][\w\']*)(?=\s*\.)', Name.Namespace, 'dotted'),
+            (r'\b([A-Z][\w\']*)', Name.Class),
+            (r'//.*?\n', Comment.Single),
+            (r'\/\*(?!/)', Comment.Multiline, 'comment'),
+            (r'\b({})\b'.format('|'.join(keywords)), Keyword),
+            (r'({})'.format('|'.join(keyopts[::-1])), Operator.Word),
+            (rf'({infix_syms}|{prefix_syms})?{operators}', Operator),
+            (r'\b({})\b'.format('|'.join(word_operators)), Operator.Word),
+            (r'\b({})\b'.format('|'.join(primitives)), Keyword.Type),
+
+            (r"[^\W\d][\w']*", Name),
+
+            (r'-?\d[\d_]*(.[\d_]*)?([eE][+\-]?\d[\d_]*)', Number.Float),
+            (r'0[xX][\da-fA-F][\da-fA-F_]*', Number.Hex),
+            (r'0[oO][0-7][0-7_]*', Number.Oct),
+            (r'0[bB][01][01_]*', Number.Bin),
+            (r'\d[\d_]*', Number.Integer),
+
+            (r"'(?:(\\[\\\"'ntbr ])|(\\[0-9]{3})|(\\x[0-9a-fA-F]{2}))'",
+             String.Char),
+            (r"'.'", String.Char),
+            (r"'", Keyword),
+
+            (r'"', String.Double, 'string'),
+
+            (r'[~?][a-z][\w\']*:', Name.Variable),
+        ],
+        'comment': [
+            (r'[^/*]+', Comment.Multiline),
+            (r'\/\*', Comment.Multiline, '#push'),
+            (r'\*\/', Comment.Multiline, '#pop'),
+            (r'\*', Comment.Multiline),
+        ],
+        'string': [
+            (r'[^\\"]+', String.Double),
+            include('escape-sequence'),
+            (r'\\\n', String.Double),
+            (r'"', String.Double, '#pop'),
+        ],
+        'dotted': [
+            (r'\s+', Text),
+            (r'\.', Punctuation),
+            (r'[A-Z][\w\']*(?=\s*\.)', Name.Namespace),
+            (r'[A-Z][\w\']*', Name.Class, '#pop'),
+            (r'[a-z_][\w\']*', Name, '#pop'),
+            default('#pop'),
+        ],
+    }
+
+
+class FStarLexer(RegexLexer):
+    """
+    For the F* language.
+    """
+
+    name = 'FStar'
+    url = 'https://www.fstar-lang.org/'
+    aliases = ['fstar']
+    filenames = ['*.fst', '*.fsti']
+    mimetypes = ['text/x-fstar']
+    version_added = '2.7'
+
+    keywords = (
+        'abstract', 'attributes', 'noeq', 'unopteq', 'and'
+        'begin', 'by', 'default', 'effect', 'else', 'end', 'ensures',
+        'exception', 'exists', 'false', 'forall', 'fun', 'function', 'if',
+        'in', 'include', 'inline', 'inline_for_extraction', 'irreducible',
+        'logic', 'match', 'module', 'mutable', 'new', 'new_effect', 'noextract',
+        'of', 'open', 'opaque', 'private', 'range_of', 'reifiable',
+        'reify', 'reflectable', 'requires', 'set_range_of', 'sub_effect',
+        'synth', 'then', 'total', 'true', 'try', 'type', 'unfold', 'unfoldable',
+        'val', 'when', 'with', 'not'
+    )
+    decl_keywords = ('let', 'rec')
+    assume_keywords = ('assume', 'admit', 'assert', 'calc')
+    keyopts = (
+        r'~', r'-', r'/\\', r'\\/', r'<:', r'<@', r'\(\|', r'\|\)', r'#', r'u#',
+        r'&', r'\(', r'\)', r'\(\)', r',', r'~>', r'->', r'<-', r'<--', r'<==>',
+        r'==>', r'\.', r'\?', r'\?\.', r'\.\[', r'\.\(', r'\.\(\|', r'\.\[\|',
+        r'\{:pattern', r':', r'::', r':=', r';', r';;', r'=', r'%\[', r'!\{',
+        r'\[', r'\[@', r'\[\|', r'\|>', r'\]', r'\|\]', r'\{', r'\|', r'\}', r'\$'
+    )
+
+    operators = r'[!$%&*+\./:<=>?@^|~-]'
+    prefix_syms = r'[!?~]'
+    infix_syms = r'[=<>@^|&+\*/$%-]'
+    primitives = ('unit', 'int', 'float', 'bool', 'string', 'char', 'list', 'array')
+
+    tokens = {
+        'escape-sequence': [
+            (r'\\[\\"\'ntbr]', String.Escape),
+            (r'\\[0-9]{3}', String.Escape),
+            (r'\\x[0-9a-fA-F]{2}', String.Escape),
+        ],
+        'root': [
+            (r'\s+', Text),
+            (r'false|true|False|True|\(\)|\[\]', Name.Builtin.Pseudo),
+            (r'\b([A-Z][\w\']*)(?=\s*\.)', Name.Namespace, 'dotted'),
+            (r'\b([A-Z][\w\']*)', Name.Class),
+            (r'\(\*(?![)])', Comment, 'comment'),
+            (r'\/\/.+$', Comment),
+            (r'\b({})\b'.format('|'.join(keywords)), Keyword),
+            (r'\b({})\b'.format('|'.join(assume_keywords)), Name.Exception),
+            (r'\b({})\b'.format('|'.join(decl_keywords)), Keyword.Declaration),
+            (r'({})'.format('|'.join(keyopts[::-1])), Operator),
+            (rf'({infix_syms}|{prefix_syms})?{operators}', Operator),
+            (r'\b({})\b'.format('|'.join(primitives)), Keyword.Type),
+
+            (r"[^\W\d][\w']*", Name),
+
+            (r'-?\d[\d_]*(.[\d_]*)?([eE][+\-]?\d[\d_]*)', Number.Float),
+            (r'0[xX][\da-fA-F][\da-fA-F_]*', Number.Hex),
+            (r'0[oO][0-7][0-7_]*', Number.Oct),
+            (r'0[bB][01][01_]*', Number.Bin),
+            (r'\d[\d_]*', Number.Integer),
+
+            (r"'(?:(\\[\\\"'ntbr ])|(\\[0-9]{3})|(\\x[0-9a-fA-F]{2}))'",
+             String.Char),
+            (r"'.'", String.Char),
+            (r"'", Keyword),  # a stray quote is another syntax element
+            (r"\`([\w\'.]+)\`", Operator.Word),  # for infix applications
+            (r"\`", Keyword),  # for quoting
+            (r'"', String.Double, 'string'),
+
+            (r'[~?][a-z][\w\']*:', Name.Variable),
+        ],
+        'comment': [
+            (r'[^(*)]+', Comment),
+            (r'\(\*', Comment, '#push'),
+            (r'\*\)', Comment, '#pop'),
+            (r'[(*)]', Comment),
+        ],
+        'string': [
+            (r'[^\\"]+', String.Double),
+            include('escape-sequence'),
+            (r'\\\n', String.Double),
+            (r'"', String.Double, '#pop'),
+        ],
+        'dotted': [
+            (r'\s+', Text),
+            (r'\.', Punctuation),
+            (r'[A-Z][\w\']*(?=\s*\.)', Name.Namespace),
+            (r'[A-Z][\w\']*', Name.Class, '#pop'),
+            (r'[a-z_][\w\']*', Name, '#pop'),
+            default('#pop'),
+        ],
+    }
diff --git a/pygments/lexers/modeling.py b/pygments/lexers/modeling.py
new file mode 100644
index 0000000000000000000000000000000000000000..d681e7f3bb1479dc1ada0c56204d51d35df4ec8f
--- /dev/null
+++ b/pygments/lexers/modeling.py
@@ -0,0 +1,366 @@
+"""
+    pygments.lexers.modeling
+    ~~~~~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for modeling languages.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import re
+
+from pygments.lexer import RegexLexer, include, bygroups, using, default
+from pygments.token import Text, Comment, Operator, Keyword, Name, String, \
+    Number, Punctuation, Whitespace
+
+from pygments.lexers.html import HtmlLexer
+from pygments.lexers import _stan_builtins
+
+__all__ = ['ModelicaLexer', 'BugsLexer', 'JagsLexer', 'StanLexer']
+
+
+class ModelicaLexer(RegexLexer):
+    """
+    For Modelica source code.
+    """
+    name = 'Modelica'
+    url = 'http://www.modelica.org/'
+    aliases = ['modelica']
+    filenames = ['*.mo']
+    mimetypes = ['text/x-modelica']
+    version_added = '1.1'
+
+    flags = re.DOTALL | re.MULTILINE
+
+    _name = r"(?:'(?:[^\\']|\\.)+'|[a-zA-Z_]\w*)"
+
+    tokens = {
+        'whitespace': [
+            (r'[\s\ufeff]+', Text),
+            (r'//[^\n]*\n?', Comment.Single),
+            (r'/\*.*?\*/', Comment.Multiline)
+        ],
+        'root': [
+            include('whitespace'),
+            (r'"', String.Double, 'string'),
+            (r'[()\[\]{},;]+', Punctuation),
+            (r'\.?[*^/+-]|\.|<>|[<>:=]=?', Operator),
+            (r'\d+(\.?\d*[eE][-+]?\d+|\.\d*)', Number.Float),
+            (r'\d+', Number.Integer),
+            (r'(abs|acos|actualStream|array|asin|assert|AssertionLevel|atan|'
+             r'atan2|backSample|Boolean|cardinality|cat|ceil|change|Clock|'
+             r'Connections|cos|cosh|cross|delay|diagonal|div|edge|exp|'
+             r'ExternalObject|fill|floor|getInstanceName|hold|homotopy|'
+             r'identity|inStream|integer|Integer|interval|inverse|isPresent|'
+             r'linspace|log|log10|matrix|max|min|mod|ndims|noClock|noEvent|'
+             r'ones|outerProduct|pre|previous|product|Real|reinit|rem|rooted|'
+             r'sample|scalar|semiLinear|shiftSample|sign|sin|sinh|size|skew|'
+             r'smooth|spatialDistribution|sqrt|StateSelect|String|subSample|'
+             r'sum|superSample|symmetric|tan|tanh|terminal|terminate|time|'
+             r'transpose|vector|zeros)\b', Name.Builtin),
+            (r'(algorithm|annotation|break|connect|constant|constrainedby|der|'
+             r'discrete|each|else|elseif|elsewhen|encapsulated|enumeration|'
+             r'equation|exit|expandable|extends|external|firstTick|final|flow|for|if|'
+             r'import|impure|in|initial|inner|input|interval|loop|nondiscrete|outer|'
+             r'output|parameter|partial|protected|public|pure|redeclare|'
+             r'replaceable|return|stream|then|when|while)\b',
+             Keyword.Reserved),
+            (r'(and|not|or)\b', Operator.Word),
+            (r'(block|class|connector|end|function|model|operator|package|'
+             r'record|type)\b', Keyword.Reserved, 'class'),
+            (r'(false|true)\b', Keyword.Constant),
+            (r'within\b', Keyword.Reserved, 'package-prefix'),
+            (_name, Name)
+        ],
+        'class': [
+            include('whitespace'),
+            (r'(function|record)\b', Keyword.Reserved),
+            (r'(if|for|when|while)\b', Keyword.Reserved, '#pop'),
+            (_name, Name.Class, '#pop'),
+            default('#pop')
+        ],
+        'package-prefix': [
+            include('whitespace'),
+            (_name, Name.Namespace, '#pop'),
+            default('#pop')
+        ],
+        'string': [
+            (r'"', String.Double, '#pop'),
+            (r'\\[\'"?\\abfnrtv]', String.Escape),
+            (r'(?i)<\s*html\s*>([^\\"]|\\.)+?(<\s*/\s*html\s*>|(?="))',
+             using(HtmlLexer)),
+            (r'<|\\?[^"\\<]+', String.Double)
+        ]
+    }
+
+
+class BugsLexer(RegexLexer):
+    """
+    Pygments Lexer for OpenBugs and WinBugs
+    models.
+    """
+
+    name = 'BUGS'
+    aliases = ['bugs', 'winbugs', 'openbugs']
+    filenames = ['*.bug']
+    url = 'https://www.mrc-bsu.cam.ac.uk/software/bugs/openbugs'
+    version_added = '1.6'
+
+    _FUNCTIONS = (
+        # Scalar functions
+        'abs', 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', 'arctanh',
+        'cloglog', 'cos', 'cosh', 'cumulative', 'cut', 'density', 'deviance',
+        'equals', 'expr', 'gammap', 'ilogit', 'icloglog', 'integral', 'log',
+        'logfact', 'loggam', 'logit', 'max', 'min', 'phi', 'post.p.value',
+        'pow', 'prior.p.value', 'probit', 'replicate.post', 'replicate.prior',
+        'round', 'sin', 'sinh', 'solution', 'sqrt', 'step', 'tan', 'tanh',
+        'trunc',
+        # Vector functions
+        'inprod', 'interp.lin', 'inverse', 'logdet', 'mean', 'eigen.vals',
+        'ode', 'prod', 'p.valueM', 'rank', 'ranked', 'replicate.postM',
+        'sd', 'sort', 'sum',
+        # Special
+        'D', 'I', 'F', 'T', 'C')
+    """ OpenBUGS built-in functions
+
+    From http://www.openbugs.info/Manuals/ModelSpecification.html#ContentsAII
+
+    This also includes
+
+    - T, C, I : Truncation and censoring.
+      ``T`` and ``C`` are in OpenBUGS. ``I`` in WinBUGS.
+    - D : ODE
+    - F : Functional http://www.openbugs.info/Examples/Functionals.html
+
+    """
+
+    _DISTRIBUTIONS = ('dbern', 'dbin', 'dcat', 'dnegbin', 'dpois',
+                      'dhyper', 'dbeta', 'dchisqr', 'ddexp', 'dexp',
+                      'dflat', 'dgamma', 'dgev', 'df', 'dggamma', 'dgpar',
+                      'dloglik', 'dlnorm', 'dlogis', 'dnorm', 'dpar',
+                      'dt', 'dunif', 'dweib', 'dmulti', 'ddirch', 'dmnorm',
+                      'dmt', 'dwish')
+    """ OpenBUGS built-in distributions
+
+    Functions from
+    http://www.openbugs.info/Manuals/ModelSpecification.html#ContentsAI
+    """
+
+    tokens = {
+        'whitespace': [
+            (r"\s+", Text),
+        ],
+        'comments': [
+            # Comments
+            (r'#.*$', Comment.Single),
+        ],
+        'root': [
+            # Comments
+            include('comments'),
+            include('whitespace'),
+            # Block start
+            (r'(model)(\s+)(\{)',
+             bygroups(Keyword.Namespace, Text, Punctuation)),
+            # Reserved Words
+            (r'(for|in)(?![\w.])', Keyword.Reserved),
+            # Built-in Functions
+            (r'({})(?=\s*\()'.format(r'|'.join(_FUNCTIONS + _DISTRIBUTIONS)),
+             Name.Builtin),
+            # Regular variable names
+            (r'[A-Za-z][\w.]*', Name),
+            # Number Literals
+            (r'[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?', Number),
+            # Punctuation
+            (r'\[|\]|\(|\)|:|,|;', Punctuation),
+            # Assignment operators
+            # SLexer makes these tokens Operators.
+            (r'<-|~', Operator),
+            # Infix and prefix operators
+            (r'\+|-|\*|/', Operator),
+            # Block
+            (r'[{}]', Punctuation),
+        ]
+    }
+
+    def analyse_text(text):
+        if re.search(r"^\s*model\s*{", text, re.M):
+            return 0.7
+        else:
+            return 0.0
+
+
+class JagsLexer(RegexLexer):
+    """
+    Pygments Lexer for JAGS.
+    """
+
+    name = 'JAGS'
+    aliases = ['jags']
+    filenames = ['*.jag', '*.bug']
+    url = 'https://mcmc-jags.sourceforge.io'
+    version_added = '1.6'
+
+    # JAGS
+    _FUNCTIONS = (
+        'abs', 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', 'arctanh',
+        'cos', 'cosh', 'cloglog',
+        'equals', 'exp', 'icloglog', 'ifelse', 'ilogit', 'log', 'logfact',
+        'loggam', 'logit', 'phi', 'pow', 'probit', 'round', 'sin', 'sinh',
+        'sqrt', 'step', 'tan', 'tanh', 'trunc', 'inprod', 'interp.lin',
+        'logdet', 'max', 'mean', 'min', 'prod', 'sum', 'sd', 'inverse',
+        'rank', 'sort', 't', 'acos', 'acosh', 'asin', 'asinh', 'atan',
+        # Truncation/Censoring (should I include)
+        'T', 'I')
+    # Distributions with density, probability and quartile functions
+    _DISTRIBUTIONS = tuple(f'[dpq]{x}' for x in
+                           ('bern', 'beta', 'dchiqsqr', 'ddexp', 'dexp',
+                            'df', 'gamma', 'gen.gamma', 'logis', 'lnorm',
+                            'negbin', 'nchisqr', 'norm', 'par', 'pois', 'weib'))
+    # Other distributions without density and probability
+    _OTHER_DISTRIBUTIONS = (
+        'dt', 'dunif', 'dbetabin', 'dbern', 'dbin', 'dcat', 'dhyper',
+        'ddirch', 'dmnorm', 'dwish', 'dmt', 'dmulti', 'dbinom', 'dchisq',
+        'dnbinom', 'dweibull', 'ddirich')
+
+    tokens = {
+        'whitespace': [
+            (r"\s+", Text),
+        ],
+        'names': [
+            # Regular variable names
+            (r'[a-zA-Z][\w.]*\b', Name),
+        ],
+        'comments': [
+            # do not use stateful comments
+            (r'(?s)/\*.*?\*/', Comment.Multiline),
+            # Comments
+            (r'#.*$', Comment.Single),
+        ],
+        'root': [
+            # Comments
+            include('comments'),
+            include('whitespace'),
+            # Block start
+            (r'(model|data)(\s+)(\{)',
+             bygroups(Keyword.Namespace, Text, Punctuation)),
+            (r'var(?![\w.])', Keyword.Declaration),
+            # Reserved Words
+            (r'(for|in)(?![\w.])', Keyword.Reserved),
+            # Builtins
+            # Need to use lookahead because . is a valid char
+            (r'({})(?=\s*\()'.format(r'|'.join(_FUNCTIONS
+                                          + _DISTRIBUTIONS
+                                          + _OTHER_DISTRIBUTIONS)),
+             Name.Builtin),
+            # Names
+            include('names'),
+            # Number Literals
+            (r'[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?', Number),
+            (r'\[|\]|\(|\)|:|,|;', Punctuation),
+            # Assignment operators
+            (r'<-|~', Operator),
+            # # JAGS includes many more than OpenBUGS
+            (r'\+|-|\*|\/|\|\|[&]{2}|[<>=]=?|\^|%.*?%', Operator),
+            (r'[{}]', Punctuation),
+        ]
+    }
+
+    def analyse_text(text):
+        if re.search(r'^\s*model\s*\{', text, re.M):
+            if re.search(r'^\s*data\s*\{', text, re.M):
+                return 0.9
+            elif re.search(r'^\s*var', text, re.M):
+                return 0.9
+            else:
+                return 0.3
+        else:
+            return 0
+
+
+class StanLexer(RegexLexer):
+    """Pygments Lexer for Stan models.
+
+    The Stan modeling language is specified in the *Stan Modeling Language
+    User's Guide and Reference Manual, v2.17.0*,
+    `pdf `__.
+    """
+
+    name = 'Stan'
+    aliases = ['stan']
+    filenames = ['*.stan']
+    url = 'https://mc-stan.org'
+    version_added = '1.6'
+
+    tokens = {
+        'whitespace': [
+            (r"\s+", Text),
+        ],
+        'comments': [
+            (r'(?s)/\*.*?\*/', Comment.Multiline),
+            # Comments
+            (r'(//|#).*$', Comment.Single),
+        ],
+        'root': [
+            (r'"[^"]*"', String),
+            # Comments
+            include('comments'),
+            # block start
+            include('whitespace'),
+            # Block start
+            (r'({})(\s*)(\{{)'.format(r'|'.join(('functions', 'data', r'transformed\s+?data',
+                        'parameters', r'transformed\s+parameters',
+                        'model', r'generated\s+quantities'))),
+             bygroups(Keyword.Namespace, Text, Punctuation)),
+            # target keyword
+            (r'target\s*\+=', Keyword),
+            # Reserved Words
+            (r'({})\b'.format(r'|'.join(_stan_builtins.KEYWORDS)), Keyword),
+            # Truncation
+            (r'T(?=\s*\[)', Keyword),
+            # Data types
+            (r'({})\b'.format(r'|'.join(_stan_builtins.TYPES)), Keyword.Type),
+             # < should be punctuation, but elsewhere I can't tell if it is in
+             # a range constraint
+            (r'(<)(\s*)(upper|lower|offset|multiplier)(\s*)(=)',
+             bygroups(Operator, Whitespace, Keyword, Whitespace, Punctuation)),
+            (r'(,)(\s*)(upper)(\s*)(=)',
+             bygroups(Punctuation, Whitespace, Keyword, Whitespace, Punctuation)),
+            # Punctuation
+            (r"[;,\[\]()]", Punctuation),
+            # Builtin
+            (r'({})(?=\s*\()'.format('|'.join(_stan_builtins.FUNCTIONS)), Name.Builtin),
+            (r'(~)(\s*)({})(?=\s*\()'.format('|'.join(_stan_builtins.DISTRIBUTIONS)),
+                bygroups(Operator, Whitespace, Name.Builtin)),
+            # Special names ending in __, like lp__
+            (r'[A-Za-z]\w*__\b', Name.Builtin.Pseudo),
+            (r'({})\b'.format(r'|'.join(_stan_builtins.RESERVED)), Keyword.Reserved),
+            # user-defined functions
+            (r'[A-Za-z]\w*(?=\s*\()]', Name.Function),
+            # Imaginary Literals
+            (r'[0-9]+(\.[0-9]*)?([eE][+-]?[0-9]+)?i', Number.Float),
+            (r'\.[0-9]+([eE][+-]?[0-9]+)?i', Number.Float),
+            (r'[0-9]+i', Number.Float),
+            # Real Literals
+            (r'[0-9]+(\.[0-9]*)?([eE][+-]?[0-9]+)?', Number.Float),
+            (r'\.[0-9]+([eE][+-]?[0-9]+)?', Number.Float),
+            # Integer Literals
+            (r'[0-9]+', Number.Integer),
+            # Regular variable names
+            (r'[A-Za-z]\w*\b', Name),
+            # Assignment operators
+            (r'<-|(?:\+|-|\.?/|\.?\*|=)?=|~', Operator),
+            # Infix, prefix and postfix operators (and = )
+            (r"\+|-|\.?\*|\.?/|\\|'|\.?\^|!=?|<=?|>=?|\|\||&&|%|\?|:|%/%|!", Operator),
+            # Block delimiters
+            (r'[{}]', Punctuation),
+            # Distribution |
+            (r'\|', Punctuation)
+        ]
+    }
+
+    def analyse_text(text):
+        if re.search(r'^\s*parameters\s*\{', text, re.M):
+            return 1.0
+        else:
+            return 0.0
diff --git a/pygments/lexers/modula2.py b/pygments/lexers/modula2.py
new file mode 100644
index 0000000000000000000000000000000000000000..713f4722f7d8d4837450a54bde5b1b7160af9911
--- /dev/null
+++ b/pygments/lexers/modula2.py
@@ -0,0 +1,1579 @@
+"""
+    pygments.lexers.modula2
+    ~~~~~~~~~~~~~~~~~~~~~~~
+
+    Multi-Dialect Lexer for Modula-2.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import re
+
+from pygments.lexer import RegexLexer, include
+from pygments.util import get_bool_opt, get_list_opt
+from pygments.token import Text, Comment, Operator, Keyword, Name, \
+    String, Number, Punctuation, Error
+
+__all__ = ['Modula2Lexer']
+
+
+# Multi-Dialect Modula-2 Lexer
+class Modula2Lexer(RegexLexer):
+    """
+    For Modula-2 source code.
+
+    The Modula-2 lexer supports several dialects.  By default, it operates in
+    fallback mode, recognising the *combined* literals, punctuation symbols
+    and operators of all supported dialects, and the *combined* reserved words
+    and builtins of PIM Modula-2, ISO Modula-2 and Modula-2 R10, while not
+    differentiating between library defined identifiers.
+
+    To select a specific dialect, a dialect option may be passed
+    or a dialect tag may be embedded into a source file.
+
+    Dialect Options:
+
+    `m2pim`
+        Select PIM Modula-2 dialect.
+    `m2iso`
+        Select ISO Modula-2 dialect.
+    `m2r10`
+        Select Modula-2 R10 dialect.
+    `objm2`
+        Select Objective Modula-2 dialect.
+
+    The PIM and ISO dialect options may be qualified with a language extension.
+
+    Language Extensions:
+
+    `+aglet`
+        Select Aglet Modula-2 extensions, available with m2iso.
+    `+gm2`
+        Select GNU Modula-2 extensions, available with m2pim.
+    `+p1`
+        Select p1 Modula-2 extensions, available with m2iso.
+    `+xds`
+        Select XDS Modula-2 extensions, available with m2iso.
+
+
+    Passing a Dialect Option via Unix Commandline Interface
+
+    Dialect options may be passed to the lexer using the `dialect` key.
+    Only one such option should be passed. If multiple dialect options are
+    passed, the first valid option is used, any subsequent options are ignored.
+
+    Examples:
+
+    `$ pygmentize -O full,dialect=m2iso -f html -o /path/to/output /path/to/input`
+        Use ISO dialect to render input to HTML output
+    `$ pygmentize -O full,dialect=m2iso+p1 -f rtf -o /path/to/output /path/to/input`
+        Use ISO dialect with p1 extensions to render input to RTF output
+
+
+    Embedding a Dialect Option within a source file
+
+    A dialect option may be embedded in a source file in form of a dialect
+    tag, a specially formatted comment that specifies a dialect option.
+
+    Dialect Tag EBNF::
+
+       dialectTag :
+           OpeningCommentDelim Prefix dialectOption ClosingCommentDelim ;
+
+       dialectOption :
+           'm2pim' | 'm2iso' | 'm2r10' | 'objm2' |
+           'm2iso+aglet' | 'm2pim+gm2' | 'm2iso+p1' | 'm2iso+xds' ;
+
+       Prefix : '!' ;
+
+       OpeningCommentDelim : '(*' ;
+
+       ClosingCommentDelim : '*)' ;
+
+    No whitespace is permitted between the tokens of a dialect tag.
+
+    In the event that a source file contains multiple dialect tags, the first
+    tag that contains a valid dialect option will be used and any subsequent
+    dialect tags will be ignored.  Ideally, a dialect tag should be placed
+    at the beginning of a source file.
+
+    An embedded dialect tag overrides a dialect option set via command line.
+
+    Examples:
+
+    ``(*!m2r10*) DEFINITION MODULE Foobar; ...``
+        Use Modula2 R10 dialect to render this source file.
+    ``(*!m2pim+gm2*) DEFINITION MODULE Bazbam; ...``
+        Use PIM dialect with GNU extensions to render this source file.
+
+
+    Algol Publication Mode:
+
+    In Algol publication mode, source text is rendered for publication of
+    algorithms in scientific papers and academic texts, following the format
+    of the Revised Algol-60 Language Report.  It is activated by passing
+    one of two corresponding styles as an option:
+
+    `algol`
+        render reserved words lowercase underline boldface
+        and builtins lowercase boldface italic
+    `algol_nu`
+        render reserved words lowercase boldface (no underlining)
+        and builtins lowercase boldface italic
+
+    The lexer automatically performs the required lowercase conversion when
+    this mode is activated.
+
+    Example:
+
+    ``$ pygmentize -O full,style=algol -f latex -o /path/to/output /path/to/input``
+        Render input file in Algol publication mode to LaTeX output.
+
+
+    Rendering Mode of First Class ADT Identifiers:
+
+    The rendering of standard library first class ADT identifiers is controlled
+    by option flag "treat_stdlib_adts_as_builtins".
+
+    When this option is turned on, standard library ADT identifiers are rendered
+    as builtins.  When it is turned off, they are rendered as ordinary library
+    identifiers.
+
+    `treat_stdlib_adts_as_builtins` (default: On)
+
+    The option is useful for dialects that support ADTs as first class objects
+    and provide ADTs in the standard library that would otherwise be built-in.
+
+    At present, only Modula-2 R10 supports library ADTs as first class objects
+    and therefore, no ADT identifiers are defined for any other dialects.
+
+    Example:
+
+    ``$ pygmentize -O full,dialect=m2r10,treat_stdlib_adts_as_builtins=Off ...``
+        Render standard library ADTs as ordinary library types.
+
+    .. versionchanged:: 2.1
+       Added multi-dialect support.
+    """
+    name = 'Modula-2'
+    url = 'http://www.modula2.org/'
+    aliases = ['modula2', 'm2']
+    filenames = ['*.def', '*.mod']
+    mimetypes = ['text/x-modula2']
+    version_added = '1.3'
+
+    flags = re.MULTILINE | re.DOTALL
+
+    tokens = {
+        'whitespace': [
+            (r'\n+', Text),  # blank lines
+            (r'\s+', Text),  # whitespace
+        ],
+        'dialecttags': [
+            # PIM Dialect Tag
+            (r'\(\*!m2pim\*\)', Comment.Special),
+            # ISO Dialect Tag
+            (r'\(\*!m2iso\*\)', Comment.Special),
+            # M2R10 Dialect Tag
+            (r'\(\*!m2r10\*\)', Comment.Special),
+            # ObjM2 Dialect Tag
+            (r'\(\*!objm2\*\)', Comment.Special),
+            # Aglet Extensions Dialect Tag
+            (r'\(\*!m2iso\+aglet\*\)', Comment.Special),
+            # GNU Extensions Dialect Tag
+            (r'\(\*!m2pim\+gm2\*\)', Comment.Special),
+            # p1 Extensions Dialect Tag
+            (r'\(\*!m2iso\+p1\*\)', Comment.Special),
+            # XDS Extensions Dialect Tag
+            (r'\(\*!m2iso\+xds\*\)', Comment.Special),
+        ],
+        'identifiers': [
+            (r'([a-zA-Z_$][\w$]*)', Name),
+        ],
+        'prefixed_number_literals': [
+            #
+            # Base-2, whole number
+            (r'0b[01]+(\'[01]+)*', Number.Bin),
+            #
+            # Base-16, whole number
+            (r'0[ux][0-9A-F]+(\'[0-9A-F]+)*', Number.Hex),
+        ],
+        'plain_number_literals': [
+            #
+            # Base-10, real number with exponent
+            (r'[0-9]+(\'[0-9]+)*'  # integral part
+             r'\.[0-9]+(\'[0-9]+)*'  # fractional part
+             r'[eE][+-]?[0-9]+(\'[0-9]+)*',  # exponent
+             Number.Float),
+            #
+            # Base-10, real number without exponent
+            (r'[0-9]+(\'[0-9]+)*'  # integral part
+             r'\.[0-9]+(\'[0-9]+)*',  # fractional part
+             Number.Float),
+            #
+            # Base-10, whole number
+            (r'[0-9]+(\'[0-9]+)*', Number.Integer),
+        ],
+        'suffixed_number_literals': [
+            #
+            # Base-8, whole number
+            (r'[0-7]+B', Number.Oct),
+            #
+            # Base-8, character code
+            (r'[0-7]+C', Number.Oct),
+            #
+            # Base-16, number
+            (r'[0-9A-F]+H', Number.Hex),
+        ],
+        'string_literals': [
+            (r'"(\\\\|\\[^\\]|[^"\\])*"', String.Double),
+            (r"'(\\\\|\\[^\\]|[^'\\])*'", String.Single),
+        ],
+        'digraph_operators': [
+            # Dot Product Operator
+            (r'\*\.', Operator),
+            # Array Concatenation Operator
+            (r'\+>', Operator),  # M2R10 + ObjM2
+            # Inequality Operator
+            (r'<>', Operator),  # ISO + PIM
+            # Less-Or-Equal, Subset
+            (r'<=', Operator),
+            # Greater-Or-Equal, Superset
+            (r'>=', Operator),
+            # Identity Operator
+            (r'==', Operator),  # M2R10 + ObjM2
+            # Type Conversion Operator
+            (r'::', Operator),  # M2R10 + ObjM2
+            # Assignment Symbol
+            (r':=', Operator),
+            # Postfix Increment Mutator
+            (r'\+\+', Operator),  # M2R10 + ObjM2
+            # Postfix Decrement Mutator
+            (r'--', Operator),  # M2R10 + ObjM2
+        ],
+        'unigraph_operators': [
+            # Arithmetic Operators
+            (r'[+-]', Operator),
+            (r'[*/]', Operator),
+            # ISO 80000-2 compliant Set Difference Operator
+            (r'\\', Operator),  # M2R10 + ObjM2
+            # Relational Operators
+            (r'[=#<>]', Operator),
+            # Dereferencing Operator
+            (r'\^', Operator),
+            # Dereferencing Operator Synonym
+            (r'@', Operator),  # ISO
+            # Logical AND Operator Synonym
+            (r'&', Operator),  # PIM + ISO
+            # Logical NOT Operator Synonym
+            (r'~', Operator),  # PIM + ISO
+            # Smalltalk Message Prefix
+            (r'`', Operator),  # ObjM2
+        ],
+        'digraph_punctuation': [
+            # Range Constructor
+            (r'\.\.', Punctuation),
+            # Opening Chevron Bracket
+            (r'<<', Punctuation),  # M2R10 + ISO
+            # Closing Chevron Bracket
+            (r'>>', Punctuation),  # M2R10 + ISO
+            # Blueprint Punctuation
+            (r'->', Punctuation),  # M2R10 + ISO
+            # Distinguish |# and # in M2 R10
+            (r'\|#', Punctuation),
+            # Distinguish ## and # in M2 R10
+            (r'##', Punctuation),
+            # Distinguish |* and * in M2 R10
+            (r'\|\*', Punctuation),
+        ],
+        'unigraph_punctuation': [
+            # Common Punctuation
+            (r'[()\[\]{},.:;|]', Punctuation),
+            # Case Label Separator Synonym
+            (r'!', Punctuation),  # ISO
+            # Blueprint Punctuation
+            (r'\?', Punctuation),  # M2R10 + ObjM2
+        ],
+        'comments': [
+            # Single Line Comment
+            (r'^//.*?\n', Comment.Single),  # M2R10 + ObjM2
+            # Block Comment
+            (r'\(\*([^$].*?)\*\)', Comment.Multiline),
+            # Template Block Comment
+            (r'/\*(.*?)\*/', Comment.Multiline),  # M2R10 + ObjM2
+        ],
+        'pragmas': [
+            # ISO Style Pragmas
+            (r'<\*.*?\*>', Comment.Preproc),  # ISO, M2R10 + ObjM2
+            # Pascal Style Pragmas
+            (r'\(\*\$.*?\*\)', Comment.Preproc),  # PIM
+        ],
+        'root': [
+            include('whitespace'),
+            include('dialecttags'),
+            include('pragmas'),
+            include('comments'),
+            include('identifiers'),
+            include('suffixed_number_literals'),  # PIM + ISO
+            include('prefixed_number_literals'),  # M2R10 + ObjM2
+            include('plain_number_literals'),
+            include('string_literals'),
+            include('digraph_punctuation'),
+            include('digraph_operators'),
+            include('unigraph_punctuation'),
+            include('unigraph_operators'),
+        ]
+    }
+
+#  C o m m o n   D a t a s e t s
+
+    # Common Reserved Words Dataset
+    common_reserved_words = (
+        # 37 common reserved words
+        'AND', 'ARRAY', 'BEGIN', 'BY', 'CASE', 'CONST', 'DEFINITION', 'DIV',
+        'DO', 'ELSE', 'ELSIF', 'END', 'EXIT', 'FOR', 'FROM', 'IF',
+        'IMPLEMENTATION', 'IMPORT', 'IN', 'LOOP', 'MOD', 'MODULE', 'NOT',
+        'OF', 'OR', 'POINTER', 'PROCEDURE', 'RECORD', 'REPEAT', 'RETURN',
+        'SET', 'THEN', 'TO', 'TYPE', 'UNTIL', 'VAR', 'WHILE',
+    )
+
+    # Common Builtins Dataset
+    common_builtins = (
+        # 16 common builtins
+        'ABS', 'BOOLEAN', 'CARDINAL', 'CHAR', 'CHR', 'FALSE', 'INTEGER',
+        'LONGINT', 'LONGREAL', 'MAX', 'MIN', 'NIL', 'ODD', 'ORD', 'REAL',
+        'TRUE',
+    )
+
+    # Common Pseudo-Module Builtins Dataset
+    common_pseudo_builtins = (
+        # 4 common pseudo builtins
+        'ADDRESS', 'BYTE', 'WORD', 'ADR'
+    )
+
+#  P I M   M o d u l a - 2   D a t a s e t s
+
+    # Lexemes to Mark as Error Tokens for PIM Modula-2
+    pim_lexemes_to_reject = (
+        '!', '`', '@', '$', '%', '?', '\\', '==', '++', '--', '::', '*.',
+        '+>', '->', '<<', '>>', '|#', '##',
+    )
+
+    # PIM Modula-2 Additional Reserved Words Dataset
+    pim_additional_reserved_words = (
+        # 3 additional reserved words
+        'EXPORT', 'QUALIFIED', 'WITH',
+    )
+
+    # PIM Modula-2 Additional Builtins Dataset
+    pim_additional_builtins = (
+        # 16 additional builtins
+        'BITSET', 'CAP', 'DEC', 'DISPOSE', 'EXCL', 'FLOAT', 'HALT', 'HIGH',
+        'INC', 'INCL', 'NEW', 'NIL', 'PROC', 'SIZE', 'TRUNC', 'VAL',
+    )
+
+    # PIM Modula-2 Additional Pseudo-Module Builtins Dataset
+    pim_additional_pseudo_builtins = (
+        # 5 additional pseudo builtins
+        'SYSTEM', 'PROCESS', 'TSIZE', 'NEWPROCESS', 'TRANSFER',
+    )
+
+#  I S O   M o d u l a - 2   D a t a s e t s
+
+    # Lexemes to Mark as Error Tokens for ISO Modula-2
+    iso_lexemes_to_reject = (
+        '`', '$', '%', '?', '\\', '==', '++', '--', '::', '*.', '+>', '->',
+        '<<', '>>', '|#', '##',
+    )
+
+    # ISO Modula-2 Additional Reserved Words Dataset
+    iso_additional_reserved_words = (
+        # 9 additional reserved words (ISO 10514-1)
+        'EXCEPT', 'EXPORT', 'FINALLY', 'FORWARD', 'PACKEDSET', 'QUALIFIED',
+        'REM', 'RETRY', 'WITH',
+        # 10 additional reserved words (ISO 10514-2 & ISO 10514-3)
+        'ABSTRACT', 'AS', 'CLASS', 'GUARD', 'INHERIT', 'OVERRIDE', 'READONLY',
+        'REVEAL', 'TRACED', 'UNSAFEGUARDED',
+    )
+
+    # ISO Modula-2 Additional Builtins Dataset
+    iso_additional_builtins = (
+        # 26 additional builtins (ISO 10514-1)
+        'BITSET', 'CAP', 'CMPLX', 'COMPLEX', 'DEC', 'DISPOSE', 'EXCL', 'FLOAT',
+        'HALT', 'HIGH', 'IM', 'INC', 'INCL', 'INT', 'INTERRUPTIBLE',  'LENGTH',
+        'LFLOAT', 'LONGCOMPLEX', 'NEW', 'PROC', 'PROTECTION', 'RE', 'SIZE',
+        'TRUNC', 'UNINTERRUBTIBLE', 'VAL',
+        # 5 additional builtins (ISO 10514-2 & ISO 10514-3)
+        'CREATE', 'DESTROY', 'EMPTY', 'ISMEMBER', 'SELF',
+    )
+
+    # ISO Modula-2 Additional Pseudo-Module Builtins Dataset
+    iso_additional_pseudo_builtins = (
+        # 14 additional builtins (SYSTEM)
+        'SYSTEM', 'BITSPERLOC', 'LOCSPERBYTE', 'LOCSPERWORD', 'LOC',
+        'ADDADR', 'SUBADR', 'DIFADR', 'MAKEADR', 'ADR',
+        'ROTATE', 'SHIFT', 'CAST', 'TSIZE',
+        # 13 additional builtins (COROUTINES)
+        'COROUTINES', 'ATTACH', 'COROUTINE', 'CURRENT', 'DETACH', 'HANDLER',
+        'INTERRUPTSOURCE', 'IOTRANSFER', 'IsATTACHED', 'LISTEN',
+        'NEWCOROUTINE', 'PROT', 'TRANSFER',
+        # 9 additional builtins (EXCEPTIONS)
+        'EXCEPTIONS', 'AllocateSource', 'CurrentNumber', 'ExceptionNumber',
+        'ExceptionSource', 'GetMessage', 'IsCurrentSource',
+        'IsExceptionalExecution', 'RAISE',
+        # 3 additional builtins (TERMINATION)
+        'TERMINATION', 'IsTerminating', 'HasHalted',
+        # 4 additional builtins (M2EXCEPTION)
+        'M2EXCEPTION', 'M2Exceptions', 'M2Exception', 'IsM2Exception',
+        'indexException', 'rangeException', 'caseSelectException',
+        'invalidLocation', 'functionException', 'wholeValueException',
+        'wholeDivException', 'realValueException', 'realDivException',
+        'complexValueException', 'complexDivException', 'protException',
+        'sysException', 'coException', 'exException',
+    )
+
+#  M o d u l a - 2   R 1 0   D a t a s e t s
+
+    # Lexemes to Mark as Error Tokens for Modula-2 R10
+    m2r10_lexemes_to_reject = (
+        '!', '`', '@', '$', '%', '&', '<>',
+    )
+
+    # Modula-2 R10 reserved words in addition to the common set
+    m2r10_additional_reserved_words = (
+        # 12 additional reserved words
+        'ALIAS', 'ARGLIST', 'BLUEPRINT', 'COPY', 'GENLIB', 'INDETERMINATE',
+        'NEW', 'NONE', 'OPAQUE', 'REFERENTIAL', 'RELEASE', 'RETAIN',
+        # 2 additional reserved words with symbolic assembly option
+        'ASM', 'REG',
+    )
+
+    # Modula-2 R10 builtins in addition to the common set
+    m2r10_additional_builtins = (
+        # 26 additional builtins
+        'CARDINAL', 'COUNT', 'EMPTY', 'EXISTS', 'INSERT', 'LENGTH', 'LONGCARD',
+        'OCTET', 'PTR', 'PRED', 'READ', 'READNEW', 'REMOVE', 'RETRIEVE', 'SORT',
+        'STORE', 'SUBSET', 'SUCC', 'TLIMIT', 'TMAX', 'TMIN', 'TRUE', 'TSIZE',
+        'UNICHAR', 'WRITE', 'WRITEF',
+    )
+
+    # Modula-2 R10 Additional Pseudo-Module Builtins Dataset
+    m2r10_additional_pseudo_builtins = (
+        # 13 additional builtins (TPROPERTIES)
+        'TPROPERTIES', 'PROPERTY', 'LITERAL', 'TPROPERTY', 'TLITERAL',
+        'TBUILTIN', 'TDYN', 'TREFC', 'TNIL', 'TBASE', 'TPRECISION',
+        'TMAXEXP', 'TMINEXP',
+        # 4 additional builtins (CONVERSION)
+        'CONVERSION', 'TSXFSIZE', 'SXF', 'VAL',
+        # 35 additional builtins (UNSAFE)
+        'UNSAFE', 'CAST', 'INTRINSIC', 'AVAIL', 'ADD', 'SUB', 'ADDC', 'SUBC',
+        'FETCHADD', 'FETCHSUB', 'SHL', 'SHR', 'ASHR', 'ROTL', 'ROTR', 'ROTLC',
+        'ROTRC', 'BWNOT', 'BWAND', 'BWOR', 'BWXOR', 'BWNAND', 'BWNOR',
+        'SETBIT', 'TESTBIT', 'LSBIT', 'MSBIT', 'CSBITS', 'BAIL', 'HALT',
+        'TODO', 'FFI', 'ADDR', 'VARGLIST', 'VARGC',
+        # 11 additional builtins (ATOMIC)
+        'ATOMIC', 'INTRINSIC', 'AVAIL', 'SWAP', 'CAS', 'INC', 'DEC', 'BWAND',
+        'BWNAND', 'BWOR', 'BWXOR',
+        # 7 additional builtins (COMPILER)
+        'COMPILER', 'DEBUG', 'MODNAME', 'PROCNAME', 'LINENUM', 'DEFAULT',
+        'HASH',
+        # 5 additional builtins (ASSEMBLER)
+        'ASSEMBLER', 'REGISTER', 'SETREG', 'GETREG', 'CODE',
+    )
+
+#  O b j e c t i v e   M o d u l a - 2   D a t a s e t s
+
+    # Lexemes to Mark as Error Tokens for Objective Modula-2
+    objm2_lexemes_to_reject = (
+        '!', '$', '%', '&', '<>',
+    )
+
+    # Objective Modula-2 Extensions
+    # reserved words in addition to Modula-2 R10
+    objm2_additional_reserved_words = (
+        # 16 additional reserved words
+        'BYCOPY', 'BYREF', 'CLASS', 'CONTINUE', 'CRITICAL', 'INOUT', 'METHOD',
+        'ON', 'OPTIONAL', 'OUT', 'PRIVATE', 'PROTECTED', 'PROTOCOL', 'PUBLIC',
+        'SUPER', 'TRY',
+    )
+
+    # Objective Modula-2 Extensions
+    # builtins in addition to Modula-2 R10
+    objm2_additional_builtins = (
+        # 3 additional builtins
+        'OBJECT', 'NO', 'YES',
+    )
+
+    # Objective Modula-2 Extensions
+    # pseudo-module builtins in addition to Modula-2 R10
+    objm2_additional_pseudo_builtins = (
+        # None
+    )
+
+#  A g l e t   M o d u l a - 2   D a t a s e t s
+
+    # Aglet Extensions
+    # reserved words in addition to ISO Modula-2
+    aglet_additional_reserved_words = (
+        # None
+    )
+
+    # Aglet Extensions
+    # builtins in addition to ISO Modula-2
+    aglet_additional_builtins = (
+        # 9 additional builtins
+        'BITSET8', 'BITSET16', 'BITSET32', 'CARDINAL8', 'CARDINAL16',
+        'CARDINAL32', 'INTEGER8', 'INTEGER16', 'INTEGER32',
+    )
+
+    # Aglet Modula-2 Extensions
+    # pseudo-module builtins in addition to ISO Modula-2
+    aglet_additional_pseudo_builtins = (
+        # None
+    )
+
+#  G N U   M o d u l a - 2   D a t a s e t s
+
+    # GNU Extensions
+    # reserved words in addition to PIM Modula-2
+    gm2_additional_reserved_words = (
+        # 10 additional reserved words
+        'ASM', '__ATTRIBUTE__', '__BUILTIN__', '__COLUMN__', '__DATE__',
+        '__FILE__', '__FUNCTION__', '__LINE__', '__MODULE__', 'VOLATILE',
+    )
+
+    # GNU Extensions
+    # builtins in addition to PIM Modula-2
+    gm2_additional_builtins = (
+        # 21 additional builtins
+        'BITSET8', 'BITSET16', 'BITSET32', 'CARDINAL8', 'CARDINAL16',
+        'CARDINAL32', 'CARDINAL64', 'COMPLEX32', 'COMPLEX64', 'COMPLEX96',
+        'COMPLEX128', 'INTEGER8', 'INTEGER16', 'INTEGER32', 'INTEGER64',
+        'REAL8', 'REAL16', 'REAL32', 'REAL96', 'REAL128', 'THROW',
+    )
+
+    # GNU Extensions
+    # pseudo-module builtins in addition to PIM Modula-2
+    gm2_additional_pseudo_builtins = (
+        # None
+    )
+
+#  p 1   M o d u l a - 2   D a t a s e t s
+
+    # p1 Extensions
+    # reserved words in addition to ISO Modula-2
+    p1_additional_reserved_words = (
+        # None
+    )
+
+    # p1 Extensions
+    # builtins in addition to ISO Modula-2
+    p1_additional_builtins = (
+        # None
+    )
+
+    # p1 Modula-2 Extensions
+    # pseudo-module builtins in addition to ISO Modula-2
+    p1_additional_pseudo_builtins = (
+        # 1 additional builtin
+        'BCD',
+    )
+
+#  X D S   M o d u l a - 2   D a t a s e t s
+
+    # XDS Extensions
+    # reserved words in addition to ISO Modula-2
+    xds_additional_reserved_words = (
+        # 1 additional reserved word
+        'SEQ',
+    )
+
+    # XDS Extensions
+    # builtins in addition to ISO Modula-2
+    xds_additional_builtins = (
+        # 9 additional builtins
+        'ASH', 'ASSERT', 'DIFFADR_TYPE', 'ENTIER', 'INDEX', 'LEN',
+        'LONGCARD', 'SHORTCARD', 'SHORTINT',
+    )
+
+    # XDS Modula-2 Extensions
+    # pseudo-module builtins in addition to ISO Modula-2
+    xds_additional_pseudo_builtins = (
+        # 22 additional builtins (SYSTEM)
+        'PROCESS', 'NEWPROCESS', 'BOOL8', 'BOOL16', 'BOOL32', 'CARD8',
+        'CARD16', 'CARD32', 'INT8', 'INT16', 'INT32', 'REF', 'MOVE',
+        'FILL', 'GET', 'PUT', 'CC', 'int', 'unsigned', 'size_t', 'void'
+        # 3 additional builtins (COMPILER)
+        'COMPILER', 'OPTION', 'EQUATION'
+    )
+
+#  P I M   S t a n d a r d   L i b r a r y   D a t a s e t s
+
+    # PIM Modula-2 Standard Library Modules Dataset
+    pim_stdlib_module_identifiers = (
+        'Terminal', 'FileSystem', 'InOut', 'RealInOut', 'MathLib0', 'Storage',
+    )
+
+    # PIM Modula-2 Standard Library Types Dataset
+    pim_stdlib_type_identifiers = (
+        'Flag', 'FlagSet', 'Response', 'Command', 'Lock', 'Permission',
+        'MediumType', 'File', 'FileProc', 'DirectoryProc', 'FileCommand',
+        'DirectoryCommand',
+    )
+
+    # PIM Modula-2 Standard Library Procedures Dataset
+    pim_stdlib_proc_identifiers = (
+        'Read', 'BusyRead', 'ReadAgain', 'Write', 'WriteString', 'WriteLn',
+        'Create', 'Lookup', 'Close', 'Delete', 'Rename', 'SetRead', 'SetWrite',
+        'SetModify', 'SetOpen', 'Doio', 'SetPos', 'GetPos', 'Length', 'Reset',
+        'Again', 'ReadWord', 'WriteWord', 'ReadChar', 'WriteChar',
+        'CreateMedium', 'DeleteMedium', 'AssignName', 'DeassignName',
+        'ReadMedium', 'LookupMedium', 'OpenInput', 'OpenOutput', 'CloseInput',
+        'CloseOutput', 'ReadString', 'ReadInt', 'ReadCard', 'ReadWrd',
+        'WriteInt', 'WriteCard', 'WriteOct', 'WriteHex', 'WriteWrd',
+        'ReadReal', 'WriteReal', 'WriteFixPt', 'WriteRealOct', 'sqrt', 'exp',
+        'ln', 'sin', 'cos', 'arctan', 'entier', 'ALLOCATE', 'DEALLOCATE',
+    )
+
+    # PIM Modula-2 Standard Library Variables Dataset
+    pim_stdlib_var_identifiers = (
+        'Done', 'termCH', 'in', 'out'
+    )
+
+    # PIM Modula-2 Standard Library Constants Dataset
+    pim_stdlib_const_identifiers = (
+        'EOL',
+    )
+
+#  I S O   S t a n d a r d   L i b r a r y   D a t a s e t s
+
+    # ISO Modula-2 Standard Library Modules Dataset
+    iso_stdlib_module_identifiers = (
+        # TO DO
+    )
+
+    # ISO Modula-2 Standard Library Types Dataset
+    iso_stdlib_type_identifiers = (
+        # TO DO
+    )
+
+    # ISO Modula-2 Standard Library Procedures Dataset
+    iso_stdlib_proc_identifiers = (
+        # TO DO
+    )
+
+    # ISO Modula-2 Standard Library Variables Dataset
+    iso_stdlib_var_identifiers = (
+        # TO DO
+    )
+
+    # ISO Modula-2 Standard Library Constants Dataset
+    iso_stdlib_const_identifiers = (
+        # TO DO
+    )
+
+#  M 2   R 1 0   S t a n d a r d   L i b r a r y   D a t a s e t s
+
+    # Modula-2 R10 Standard Library ADTs Dataset
+    m2r10_stdlib_adt_identifiers = (
+        'BCD', 'LONGBCD', 'BITSET', 'SHORTBITSET', 'LONGBITSET',
+        'LONGLONGBITSET', 'COMPLEX', 'LONGCOMPLEX', 'SHORTCARD', 'LONGLONGCARD',
+        'SHORTINT', 'LONGLONGINT', 'POSINT', 'SHORTPOSINT', 'LONGPOSINT',
+        'LONGLONGPOSINT', 'BITSET8', 'BITSET16', 'BITSET32', 'BITSET64',
+        'BITSET128', 'BS8', 'BS16', 'BS32', 'BS64', 'BS128', 'CARDINAL8',
+        'CARDINAL16', 'CARDINAL32', 'CARDINAL64', 'CARDINAL128', 'CARD8',
+        'CARD16', 'CARD32', 'CARD64', 'CARD128', 'INTEGER8', 'INTEGER16',
+        'INTEGER32', 'INTEGER64', 'INTEGER128', 'INT8', 'INT16', 'INT32',
+        'INT64', 'INT128', 'STRING', 'UNISTRING',
+    )
+
+    # Modula-2 R10 Standard Library Blueprints Dataset
+    m2r10_stdlib_blueprint_identifiers = (
+        'ProtoRoot', 'ProtoComputational', 'ProtoNumeric', 'ProtoScalar',
+        'ProtoNonScalar', 'ProtoCardinal', 'ProtoInteger', 'ProtoReal',
+        'ProtoComplex', 'ProtoVector', 'ProtoTuple', 'ProtoCompArray',
+        'ProtoCollection', 'ProtoStaticArray', 'ProtoStaticSet',
+        'ProtoStaticString', 'ProtoArray', 'ProtoString', 'ProtoSet',
+        'ProtoMultiSet', 'ProtoDictionary', 'ProtoMultiDict', 'ProtoExtension',
+        'ProtoIO', 'ProtoCardMath', 'ProtoIntMath', 'ProtoRealMath',
+    )
+
+    # Modula-2 R10 Standard Library Modules Dataset
+    m2r10_stdlib_module_identifiers = (
+        'ASCII', 'BooleanIO', 'CharIO', 'UnicharIO', 'OctetIO',
+        'CardinalIO', 'LongCardIO', 'IntegerIO', 'LongIntIO', 'RealIO',
+        'LongRealIO', 'BCDIO', 'LongBCDIO', 'CardMath', 'LongCardMath',
+        'IntMath', 'LongIntMath', 'RealMath', 'LongRealMath', 'BCDMath',
+        'LongBCDMath', 'FileIO', 'FileSystem', 'Storage', 'IOSupport',
+    )
+
+    # Modula-2 R10 Standard Library Types Dataset
+    m2r10_stdlib_type_identifiers = (
+        'File', 'Status',
+        # TO BE COMPLETED
+    )
+
+    # Modula-2 R10 Standard Library Procedures Dataset
+    m2r10_stdlib_proc_identifiers = (
+        'ALLOCATE', 'DEALLOCATE', 'SIZE',
+        # TO BE COMPLETED
+    )
+
+    # Modula-2 R10 Standard Library Variables Dataset
+    m2r10_stdlib_var_identifiers = (
+        'stdIn', 'stdOut', 'stdErr',
+    )
+
+    # Modula-2 R10 Standard Library Constants Dataset
+    m2r10_stdlib_const_identifiers = (
+        'pi', 'tau',
+    )
+
+#  D i a l e c t s
+
+    # Dialect modes
+    dialects = (
+        'unknown',
+        'm2pim', 'm2iso', 'm2r10', 'objm2',
+        'm2iso+aglet', 'm2pim+gm2', 'm2iso+p1', 'm2iso+xds',
+    )
+
+#   D a t a b a s e s
+
+    # Lexemes to Mark as Errors Database
+    lexemes_to_reject_db = {
+        # Lexemes to reject for unknown dialect
+        'unknown': (
+            # LEAVE THIS EMPTY
+        ),
+        # Lexemes to reject for PIM Modula-2
+        'm2pim': (
+            pim_lexemes_to_reject,
+        ),
+        # Lexemes to reject for ISO Modula-2
+        'm2iso': (
+            iso_lexemes_to_reject,
+        ),
+        # Lexemes to reject for Modula-2 R10
+        'm2r10': (
+            m2r10_lexemes_to_reject,
+        ),
+        # Lexemes to reject for Objective Modula-2
+        'objm2': (
+            objm2_lexemes_to_reject,
+        ),
+        # Lexemes to reject for Aglet Modula-2
+        'm2iso+aglet': (
+            iso_lexemes_to_reject,
+        ),
+        # Lexemes to reject for GNU Modula-2
+        'm2pim+gm2': (
+            pim_lexemes_to_reject,
+        ),
+        # Lexemes to reject for p1 Modula-2
+        'm2iso+p1': (
+            iso_lexemes_to_reject,
+        ),
+        # Lexemes to reject for XDS Modula-2
+        'm2iso+xds': (
+            iso_lexemes_to_reject,
+        ),
+    }
+
+    # Reserved Words Database
+    reserved_words_db = {
+        # Reserved words for unknown dialect
+        'unknown': (
+            common_reserved_words,
+            pim_additional_reserved_words,
+            iso_additional_reserved_words,
+            m2r10_additional_reserved_words,
+        ),
+
+        # Reserved words for PIM Modula-2
+        'm2pim': (
+            common_reserved_words,
+            pim_additional_reserved_words,
+        ),
+
+        # Reserved words for Modula-2 R10
+        'm2iso': (
+            common_reserved_words,
+            iso_additional_reserved_words,
+        ),
+
+        # Reserved words for ISO Modula-2
+        'm2r10': (
+            common_reserved_words,
+            m2r10_additional_reserved_words,
+        ),
+
+        # Reserved words for Objective Modula-2
+        'objm2': (
+            common_reserved_words,
+            m2r10_additional_reserved_words,
+            objm2_additional_reserved_words,
+        ),
+
+        # Reserved words for Aglet Modula-2 Extensions
+        'm2iso+aglet': (
+            common_reserved_words,
+            iso_additional_reserved_words,
+            aglet_additional_reserved_words,
+        ),
+
+        # Reserved words for GNU Modula-2 Extensions
+        'm2pim+gm2': (
+            common_reserved_words,
+            pim_additional_reserved_words,
+            gm2_additional_reserved_words,
+        ),
+
+        # Reserved words for p1 Modula-2 Extensions
+        'm2iso+p1': (
+            common_reserved_words,
+            iso_additional_reserved_words,
+            p1_additional_reserved_words,
+        ),
+
+        # Reserved words for XDS Modula-2 Extensions
+        'm2iso+xds': (
+            common_reserved_words,
+            iso_additional_reserved_words,
+            xds_additional_reserved_words,
+        ),
+    }
+
+    # Builtins Database
+    builtins_db = {
+        # Builtins for unknown dialect
+        'unknown': (
+            common_builtins,
+            pim_additional_builtins,
+            iso_additional_builtins,
+            m2r10_additional_builtins,
+        ),
+
+        # Builtins for PIM Modula-2
+        'm2pim': (
+            common_builtins,
+            pim_additional_builtins,
+        ),
+
+        # Builtins for ISO Modula-2
+        'm2iso': (
+            common_builtins,
+            iso_additional_builtins,
+        ),
+
+        # Builtins for ISO Modula-2
+        'm2r10': (
+            common_builtins,
+            m2r10_additional_builtins,
+        ),
+
+        # Builtins for Objective Modula-2
+        'objm2': (
+            common_builtins,
+            m2r10_additional_builtins,
+            objm2_additional_builtins,
+        ),
+
+        # Builtins for Aglet Modula-2 Extensions
+        'm2iso+aglet': (
+            common_builtins,
+            iso_additional_builtins,
+            aglet_additional_builtins,
+        ),
+
+        # Builtins for GNU Modula-2 Extensions
+        'm2pim+gm2': (
+            common_builtins,
+            pim_additional_builtins,
+            gm2_additional_builtins,
+        ),
+
+        # Builtins for p1 Modula-2 Extensions
+        'm2iso+p1': (
+            common_builtins,
+            iso_additional_builtins,
+            p1_additional_builtins,
+        ),
+
+        # Builtins for XDS Modula-2 Extensions
+        'm2iso+xds': (
+            common_builtins,
+            iso_additional_builtins,
+            xds_additional_builtins,
+        ),
+    }
+
+    # Pseudo-Module Builtins Database
+    pseudo_builtins_db = {
+        # Builtins for unknown dialect
+        'unknown': (
+            common_pseudo_builtins,
+            pim_additional_pseudo_builtins,
+            iso_additional_pseudo_builtins,
+            m2r10_additional_pseudo_builtins,
+        ),
+
+        # Builtins for PIM Modula-2
+        'm2pim': (
+            common_pseudo_builtins,
+            pim_additional_pseudo_builtins,
+        ),
+
+        # Builtins for ISO Modula-2
+        'm2iso': (
+            common_pseudo_builtins,
+            iso_additional_pseudo_builtins,
+        ),
+
+        # Builtins for ISO Modula-2
+        'm2r10': (
+            common_pseudo_builtins,
+            m2r10_additional_pseudo_builtins,
+        ),
+
+        # Builtins for Objective Modula-2
+        'objm2': (
+            common_pseudo_builtins,
+            m2r10_additional_pseudo_builtins,
+            objm2_additional_pseudo_builtins,
+        ),
+
+        # Builtins for Aglet Modula-2 Extensions
+        'm2iso+aglet': (
+            common_pseudo_builtins,
+            iso_additional_pseudo_builtins,
+            aglet_additional_pseudo_builtins,
+        ),
+
+        # Builtins for GNU Modula-2 Extensions
+        'm2pim+gm2': (
+            common_pseudo_builtins,
+            pim_additional_pseudo_builtins,
+            gm2_additional_pseudo_builtins,
+        ),
+
+        # Builtins for p1 Modula-2 Extensions
+        'm2iso+p1': (
+            common_pseudo_builtins,
+            iso_additional_pseudo_builtins,
+            p1_additional_pseudo_builtins,
+        ),
+
+        # Builtins for XDS Modula-2 Extensions
+        'm2iso+xds': (
+            common_pseudo_builtins,
+            iso_additional_pseudo_builtins,
+            xds_additional_pseudo_builtins,
+        ),
+    }
+
+    # Standard Library ADTs Database
+    stdlib_adts_db = {
+        # Empty entry for unknown dialect
+        'unknown': (
+            # LEAVE THIS EMPTY
+        ),
+        # Standard Library ADTs for PIM Modula-2
+        'm2pim': (
+            # No first class library types
+        ),
+
+        # Standard Library ADTs for ISO Modula-2
+        'm2iso': (
+            # No first class library types
+        ),
+
+        # Standard Library ADTs for Modula-2 R10
+        'm2r10': (
+            m2r10_stdlib_adt_identifiers,
+        ),
+
+        # Standard Library ADTs for Objective Modula-2
+        'objm2': (
+            m2r10_stdlib_adt_identifiers,
+        ),
+
+        # Standard Library ADTs for Aglet Modula-2
+        'm2iso+aglet': (
+            # No first class library types
+        ),
+
+        # Standard Library ADTs for GNU Modula-2
+        'm2pim+gm2': (
+            # No first class library types
+        ),
+
+        # Standard Library ADTs for p1 Modula-2
+        'm2iso+p1': (
+            # No first class library types
+        ),
+
+        # Standard Library ADTs for XDS Modula-2
+        'm2iso+xds': (
+            # No first class library types
+        ),
+    }
+
+    # Standard Library Modules Database
+    stdlib_modules_db = {
+        # Empty entry for unknown dialect
+        'unknown': (
+            # LEAVE THIS EMPTY
+        ),
+        # Standard Library Modules for PIM Modula-2
+        'm2pim': (
+            pim_stdlib_module_identifiers,
+        ),
+
+        # Standard Library Modules for ISO Modula-2
+        'm2iso': (
+            iso_stdlib_module_identifiers,
+        ),
+
+        # Standard Library Modules for Modula-2 R10
+        'm2r10': (
+            m2r10_stdlib_blueprint_identifiers,
+            m2r10_stdlib_module_identifiers,
+            m2r10_stdlib_adt_identifiers,
+        ),
+
+        # Standard Library Modules for Objective Modula-2
+        'objm2': (
+            m2r10_stdlib_blueprint_identifiers,
+            m2r10_stdlib_module_identifiers,
+        ),
+
+        # Standard Library Modules for Aglet Modula-2
+        'm2iso+aglet': (
+            iso_stdlib_module_identifiers,
+        ),
+
+        # Standard Library Modules for GNU Modula-2
+        'm2pim+gm2': (
+            pim_stdlib_module_identifiers,
+        ),
+
+        # Standard Library Modules for p1 Modula-2
+        'm2iso+p1': (
+            iso_stdlib_module_identifiers,
+        ),
+
+        # Standard Library Modules for XDS Modula-2
+        'm2iso+xds': (
+            iso_stdlib_module_identifiers,
+        ),
+    }
+
+    # Standard Library Types Database
+    stdlib_types_db = {
+        # Empty entry for unknown dialect
+        'unknown': (
+            # LEAVE THIS EMPTY
+        ),
+        # Standard Library Types for PIM Modula-2
+        'm2pim': (
+            pim_stdlib_type_identifiers,
+        ),
+
+        # Standard Library Types for ISO Modula-2
+        'm2iso': (
+            iso_stdlib_type_identifiers,
+        ),
+
+        # Standard Library Types for Modula-2 R10
+        'm2r10': (
+            m2r10_stdlib_type_identifiers,
+        ),
+
+        # Standard Library Types for Objective Modula-2
+        'objm2': (
+            m2r10_stdlib_type_identifiers,
+        ),
+
+        # Standard Library Types for Aglet Modula-2
+        'm2iso+aglet': (
+            iso_stdlib_type_identifiers,
+        ),
+
+        # Standard Library Types for GNU Modula-2
+        'm2pim+gm2': (
+            pim_stdlib_type_identifiers,
+        ),
+
+        # Standard Library Types for p1 Modula-2
+        'm2iso+p1': (
+            iso_stdlib_type_identifiers,
+        ),
+
+        # Standard Library Types for XDS Modula-2
+        'm2iso+xds': (
+            iso_stdlib_type_identifiers,
+        ),
+    }
+
+    # Standard Library Procedures Database
+    stdlib_procedures_db = {
+        # Empty entry for unknown dialect
+        'unknown': (
+            # LEAVE THIS EMPTY
+        ),
+        # Standard Library Procedures for PIM Modula-2
+        'm2pim': (
+            pim_stdlib_proc_identifiers,
+        ),
+
+        # Standard Library Procedures for ISO Modula-2
+        'm2iso': (
+            iso_stdlib_proc_identifiers,
+        ),
+
+        # Standard Library Procedures for Modula-2 R10
+        'm2r10': (
+            m2r10_stdlib_proc_identifiers,
+        ),
+
+        # Standard Library Procedures for Objective Modula-2
+        'objm2': (
+            m2r10_stdlib_proc_identifiers,
+        ),
+
+        # Standard Library Procedures for Aglet Modula-2
+        'm2iso+aglet': (
+            iso_stdlib_proc_identifiers,
+        ),
+
+        # Standard Library Procedures for GNU Modula-2
+        'm2pim+gm2': (
+            pim_stdlib_proc_identifiers,
+        ),
+
+        # Standard Library Procedures for p1 Modula-2
+        'm2iso+p1': (
+            iso_stdlib_proc_identifiers,
+        ),
+
+        # Standard Library Procedures for XDS Modula-2
+        'm2iso+xds': (
+            iso_stdlib_proc_identifiers,
+        ),
+    }
+
+    # Standard Library Variables Database
+    stdlib_variables_db = {
+        # Empty entry for unknown dialect
+        'unknown': (
+            # LEAVE THIS EMPTY
+        ),
+        # Standard Library Variables for PIM Modula-2
+        'm2pim': (
+            pim_stdlib_var_identifiers,
+        ),
+
+        # Standard Library Variables for ISO Modula-2
+        'm2iso': (
+            iso_stdlib_var_identifiers,
+        ),
+
+        # Standard Library Variables for Modula-2 R10
+        'm2r10': (
+            m2r10_stdlib_var_identifiers,
+        ),
+
+        # Standard Library Variables for Objective Modula-2
+        'objm2': (
+            m2r10_stdlib_var_identifiers,
+        ),
+
+        # Standard Library Variables for Aglet Modula-2
+        'm2iso+aglet': (
+            iso_stdlib_var_identifiers,
+        ),
+
+        # Standard Library Variables for GNU Modula-2
+        'm2pim+gm2': (
+            pim_stdlib_var_identifiers,
+        ),
+
+        # Standard Library Variables for p1 Modula-2
+        'm2iso+p1': (
+            iso_stdlib_var_identifiers,
+        ),
+
+        # Standard Library Variables for XDS Modula-2
+        'm2iso+xds': (
+            iso_stdlib_var_identifiers,
+        ),
+    }
+
+    # Standard Library Constants Database
+    stdlib_constants_db = {
+        # Empty entry for unknown dialect
+        'unknown': (
+            # LEAVE THIS EMPTY
+        ),
+        # Standard Library Constants for PIM Modula-2
+        'm2pim': (
+            pim_stdlib_const_identifiers,
+        ),
+
+        # Standard Library Constants for ISO Modula-2
+        'm2iso': (
+            iso_stdlib_const_identifiers,
+        ),
+
+        # Standard Library Constants for Modula-2 R10
+        'm2r10': (
+            m2r10_stdlib_const_identifiers,
+        ),
+
+        # Standard Library Constants for Objective Modula-2
+        'objm2': (
+            m2r10_stdlib_const_identifiers,
+        ),
+
+        # Standard Library Constants for Aglet Modula-2
+        'm2iso+aglet': (
+            iso_stdlib_const_identifiers,
+        ),
+
+        # Standard Library Constants for GNU Modula-2
+        'm2pim+gm2': (
+            pim_stdlib_const_identifiers,
+        ),
+
+        # Standard Library Constants for p1 Modula-2
+        'm2iso+p1': (
+            iso_stdlib_const_identifiers,
+        ),
+
+        # Standard Library Constants for XDS Modula-2
+        'm2iso+xds': (
+            iso_stdlib_const_identifiers,
+        ),
+    }
+
+#   M e t h o d s
+
+    # initialise a lexer instance
+    def __init__(self, **options):
+        #
+        # check dialect options
+        #
+        dialects = get_list_opt(options, 'dialect', [])
+        #
+        for dialect_option in dialects:
+            if dialect_option in self.dialects[1:-1]:
+                # valid dialect option found
+                self.set_dialect(dialect_option)
+                break
+        #
+        # Fallback Mode (DEFAULT)
+        else:
+            # no valid dialect option
+            self.set_dialect('unknown')
+        #
+        self.dialect_set_by_tag = False
+        #
+        # check style options
+        #
+        styles = get_list_opt(options, 'style', [])
+        #
+        # use lowercase mode for Algol style
+        if 'algol' in styles or 'algol_nu' in styles:
+            self.algol_publication_mode = True
+        else:
+            self.algol_publication_mode = False
+        #
+        # Check option flags
+        #
+        self.treat_stdlib_adts_as_builtins = get_bool_opt(
+            options, 'treat_stdlib_adts_as_builtins', True)
+        #
+        # call superclass initialiser
+        RegexLexer.__init__(self, **options)
+
+    # Set lexer to a specified dialect
+    def set_dialect(self, dialect_id):
+        #
+        # if __debug__:
+        #    print 'entered set_dialect with arg: ', dialect_id
+        #
+        # check dialect name against known dialects
+        if dialect_id not in self.dialects:
+            dialect = 'unknown'  # default
+        else:
+            dialect = dialect_id
+        #
+        # compose lexemes to reject set
+        lexemes_to_reject_set = set()
+        # add each list of reject lexemes for this dialect
+        for list in self.lexemes_to_reject_db[dialect]:
+            lexemes_to_reject_set.update(set(list))
+        #
+        # compose reserved words set
+        reswords_set = set()
+        # add each list of reserved words for this dialect
+        for list in self.reserved_words_db[dialect]:
+            reswords_set.update(set(list))
+        #
+        # compose builtins set
+        builtins_set = set()
+        # add each list of builtins for this dialect excluding reserved words
+        for list in self.builtins_db[dialect]:
+            builtins_set.update(set(list).difference(reswords_set))
+        #
+        # compose pseudo-builtins set
+        pseudo_builtins_set = set()
+        # add each list of builtins for this dialect excluding reserved words
+        for list in self.pseudo_builtins_db[dialect]:
+            pseudo_builtins_set.update(set(list).difference(reswords_set))
+        #
+        # compose ADTs set
+        adts_set = set()
+        # add each list of ADTs for this dialect excluding reserved words
+        for list in self.stdlib_adts_db[dialect]:
+            adts_set.update(set(list).difference(reswords_set))
+        #
+        # compose modules set
+        modules_set = set()
+        # add each list of builtins for this dialect excluding builtins
+        for list in self.stdlib_modules_db[dialect]:
+            modules_set.update(set(list).difference(builtins_set))
+        #
+        # compose types set
+        types_set = set()
+        # add each list of types for this dialect excluding builtins
+        for list in self.stdlib_types_db[dialect]:
+            types_set.update(set(list).difference(builtins_set))
+        #
+        # compose procedures set
+        procedures_set = set()
+        # add each list of procedures for this dialect excluding builtins
+        for list in self.stdlib_procedures_db[dialect]:
+            procedures_set.update(set(list).difference(builtins_set))
+        #
+        # compose variables set
+        variables_set = set()
+        # add each list of variables for this dialect excluding builtins
+        for list in self.stdlib_variables_db[dialect]:
+            variables_set.update(set(list).difference(builtins_set))
+        #
+        # compose constants set
+        constants_set = set()
+        # add each list of constants for this dialect excluding builtins
+        for list in self.stdlib_constants_db[dialect]:
+            constants_set.update(set(list).difference(builtins_set))
+        #
+        # update lexer state
+        self.dialect = dialect
+        self.lexemes_to_reject = lexemes_to_reject_set
+        self.reserved_words = reswords_set
+        self.builtins = builtins_set
+        self.pseudo_builtins = pseudo_builtins_set
+        self.adts = adts_set
+        self.modules = modules_set
+        self.types = types_set
+        self.procedures = procedures_set
+        self.variables = variables_set
+        self.constants = constants_set
+        #
+        # if __debug__:
+        #    print 'exiting set_dialect'
+        #    print ' self.dialect: ', self.dialect
+        #    print ' self.lexemes_to_reject: ', self.lexemes_to_reject
+        #    print ' self.reserved_words: ', self.reserved_words
+        #    print ' self.builtins: ', self.builtins
+        #    print ' self.pseudo_builtins: ', self.pseudo_builtins
+        #    print ' self.adts: ', self.adts
+        #    print ' self.modules: ', self.modules
+        #    print ' self.types: ', self.types
+        #    print ' self.procedures: ', self.procedures
+        #    print ' self.variables: ', self.variables
+        #    print ' self.types: ', self.types
+        #    print ' self.constants: ', self.constants
+
+    # Extracts a dialect name from a dialect tag comment string  and checks
+    # the extracted name against known dialects.  If a match is found,  the
+    # matching name is returned, otherwise dialect id 'unknown' is returned
+    def get_dialect_from_dialect_tag(self, dialect_tag):
+        #
+        # if __debug__:
+        #    print 'entered get_dialect_from_dialect_tag with arg: ', dialect_tag
+        #
+        # constants
+        left_tag_delim = '(*!'
+        right_tag_delim = '*)'
+        left_tag_delim_len = len(left_tag_delim)
+        right_tag_delim_len = len(right_tag_delim)
+        indicator_start = left_tag_delim_len
+        indicator_end = -(right_tag_delim_len)
+        #
+        # check comment string for dialect indicator
+        if len(dialect_tag) > (left_tag_delim_len + right_tag_delim_len) \
+           and dialect_tag.startswith(left_tag_delim) \
+           and dialect_tag.endswith(right_tag_delim):
+            #
+            # if __debug__:
+            #    print 'dialect tag found'
+            #
+            # extract dialect indicator
+            indicator = dialect_tag[indicator_start:indicator_end]
+            #
+            # if __debug__:
+            #    print 'extracted: ', indicator
+            #
+            # check against known dialects
+            for index in range(1, len(self.dialects)):
+                #
+                # if __debug__:
+                #    print 'dialects[', index, ']: ', self.dialects[index]
+                #
+                if indicator == self.dialects[index]:
+                    #
+                    # if __debug__:
+                    #    print 'matching dialect found'
+                    #
+                    # indicator matches known dialect
+                    return indicator
+            else:
+                # indicator does not match any dialect
+                return 'unknown'  # default
+        else:
+            # invalid indicator string
+            return 'unknown'  # default
+
+    # intercept the token stream, modify token attributes and return them
+    def get_tokens_unprocessed(self, text):
+        for index, token, value in RegexLexer.get_tokens_unprocessed(self, text):
+            #
+            # check for dialect tag if dialect has not been set by tag
+            if not self.dialect_set_by_tag and token == Comment.Special:
+                indicated_dialect = self.get_dialect_from_dialect_tag(value)
+                if indicated_dialect != 'unknown':
+                    # token is a dialect indicator
+                    # reset reserved words and builtins
+                    self.set_dialect(indicated_dialect)
+                    self.dialect_set_by_tag = True
+            #
+            # check for reserved words, predefined and stdlib identifiers
+            if token is Name:
+                if value in self.reserved_words:
+                    token = Keyword.Reserved
+                    if self.algol_publication_mode:
+                        value = value.lower()
+                #
+                elif value in self.builtins:
+                    token = Name.Builtin
+                    if self.algol_publication_mode:
+                        value = value.lower()
+                #
+                elif value in self.pseudo_builtins:
+                    token = Name.Builtin.Pseudo
+                    if self.algol_publication_mode:
+                        value = value.lower()
+                #
+                elif value in self.adts:
+                    if not self.treat_stdlib_adts_as_builtins:
+                        token = Name.Namespace
+                    else:
+                        token = Name.Builtin.Pseudo
+                        if self.algol_publication_mode:
+                            value = value.lower()
+                #
+                elif value in self.modules:
+                    token = Name.Namespace
+                #
+                elif value in self.types:
+                    token = Name.Class
+                #
+                elif value in self.procedures:
+                    token = Name.Function
+                #
+                elif value in self.variables:
+                    token = Name.Variable
+                #
+                elif value in self.constants:
+                    token = Name.Constant
+            #
+            elif token in Number:
+                #
+                # mark prefix number literals as error for PIM and ISO dialects
+                if self.dialect not in ('unknown', 'm2r10', 'objm2'):
+                    if "'" in value or value[0:2] in ('0b', '0x', '0u'):
+                        token = Error
+                #
+                elif self.dialect in ('m2r10', 'objm2'):
+                    # mark base-8 number literals as errors for M2 R10 and ObjM2
+                    if token is Number.Oct:
+                        token = Error
+                    # mark suffix base-16 literals as errors for M2 R10 and ObjM2
+                    elif token is Number.Hex and 'H' in value:
+                        token = Error
+                    # mark real numbers with E as errors for M2 R10 and ObjM2
+                    elif token is Number.Float and 'E' in value:
+                        token = Error
+            #
+            elif token in Comment:
+                #
+                # mark single line comment as error for PIM and ISO dialects
+                if token is Comment.Single:
+                    if self.dialect not in ('unknown', 'm2r10', 'objm2'):
+                        token = Error
+                #
+                if token is Comment.Preproc:
+                    # mark ISO pragma as error for PIM dialects
+                    if value.startswith('<*') and \
+                       self.dialect.startswith('m2pim'):
+                        token = Error
+                    # mark PIM pragma as comment for other dialects
+                    elif value.startswith('(*$') and \
+                            self.dialect != 'unknown' and \
+                            not self.dialect.startswith('m2pim'):
+                        token = Comment.Multiline
+            #
+            else:  # token is neither Name nor Comment
+                #
+                # mark lexemes matching the dialect's error token set as errors
+                if value in self.lexemes_to_reject:
+                    token = Error
+                #
+                # substitute lexemes when in Algol mode
+                if self.algol_publication_mode:
+                    if value == '#':
+                        value = '≠'
+                    elif value == '<=':
+                        value = '≤'
+                    elif value == '>=':
+                        value = '≥'
+                    elif value == '==':
+                        value = '≡'
+                    elif value == '*.':
+                        value = '•'
+
+            # return result
+            yield index, token, value
+
+    def analyse_text(text):
+        """It's Pascal-like, but does not use FUNCTION -- uses PROCEDURE
+        instead."""
+
+        # Check if this looks like Pascal, if not, bail out early
+        if not ('(*' in text and '*)' in text and ':=' in text):
+            return
+
+        result = 0
+        # Procedure is in Modula2
+        if re.search(r'\bPROCEDURE\b', text):
+            result += 0.6
+
+        # FUNCTION is only valid in Pascal, but not in Modula2
+        if re.search(r'\bFUNCTION\b', text):
+            result = 0.0
+
+        return result
diff --git a/pygments/lexers/mojo.py b/pygments/lexers/mojo.py
new file mode 100644
index 0000000000000000000000000000000000000000..4df18c4f9c94520ab852ba2f66bb50518eb64be7
--- /dev/null
+++ b/pygments/lexers/mojo.py
@@ -0,0 +1,707 @@
+"""
+    pygments.lexers.mojo
+    ~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for Mojo and related languages.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import keyword
+
+from pygments import unistring as uni
+from pygments.lexer import (
+    RegexLexer,
+    bygroups,
+    combined,
+    default,
+    include,
+    this,
+    using,
+    words,
+)
+from pygments.token import (
+    Comment,
+    # Error,
+    Keyword,
+    Name,
+    Number,
+    Operator,
+    Punctuation,
+    String,
+    Text,
+    Whitespace,
+)
+from pygments.util import shebang_matches
+
+__all__ = ["MojoLexer"]
+
+
+class MojoLexer(RegexLexer):
+    """
+    For Mojo source code (version 24.2.1).
+    """
+
+    name = "Mojo"
+    url = "https://docs.modular.com/mojo/"
+    aliases = ["mojo", "🔥"]
+    filenames = [
+        "*.mojo",
+        "*.🔥",
+    ]
+    mimetypes = [
+        "text/x-mojo",
+        "application/x-mojo",
+    ]
+    version_added = "2.18"
+
+    uni_name = f"[{uni.xid_start}][{uni.xid_continue}]*"
+
+    def innerstring_rules(ttype):
+        return [
+            # the old style '%s' % (...) string formatting (still valid in Py3)
+            (
+                r"%(\(\w+\))?[-#0 +]*([0-9]+|[*])?(\.([0-9]+|[*]))?"
+                "[hlL]?[E-GXc-giorsaux%]",
+                String.Interpol,
+            ),
+            # the new style '{}'.format(...) string formatting
+            (
+                r"\{"
+                r"((\w+)((\.\w+)|(\[[^\]]+\]))*)?"  # field name
+                r"(\![sra])?"  # conversion
+                r"(\:(.?[<>=\^])?[-+ ]?#?0?(\d+)?,?(\.\d+)?[E-GXb-gnosx%]?)?"
+                r"\}",
+                String.Interpol,
+            ),
+            # backslashes, quotes and formatting signs must be parsed one at a time
+            (r'[^\\\'"%{\n]+', ttype),
+            (r'[\'"\\]', ttype),
+            # unhandled string formatting sign
+            (r"%|(\{{1,2})", ttype),
+            # newlines are an error (use "nl" state)
+        ]
+
+    def fstring_rules(ttype):
+        return [
+            # Assuming that a '}' is the closing brace after format specifier.
+            # Sadly, this means that we won't detect syntax error. But it's
+            # more important to parse correct syntax correctly, than to
+            # highlight invalid syntax.
+            (r"\}", String.Interpol),
+            (r"\{", String.Interpol, "expr-inside-fstring"),
+            # backslashes, quotes and formatting signs must be parsed one at a time
+            (r'[^\\\'"{}\n]+', ttype),
+            (r'[\'"\\]', ttype),
+            # newlines are an error (use "nl" state)
+        ]
+
+    tokens = {
+        "root": [
+            (r"\s+", Whitespace),
+            (
+                r'^(\s*)([rRuUbB]{,2})("""(?:.|\n)*?""")',
+                bygroups(Whitespace, String.Affix, String.Doc),
+            ),
+            (
+                r"^(\s*)([rRuUbB]{,2})('''(?:.|\n)*?''')",
+                bygroups(Whitespace, String.Affix, String.Doc),
+            ),
+            (r"\A#!.+$", Comment.Hashbang),
+            (r"#.*$", Comment.Single),
+            (r"\\\n", Whitespace),
+            (r"\\", Whitespace),
+            include("keywords"),
+            include("soft-keywords"),
+            # In the original PR, all the below here used ((?:\s|\\\s)+) to
+            # designate whitespace, but I can't find any example of this being
+            # needed in the example file, so we're replacing it with `\s+`.
+            (
+                r"(alias)(\s+)",
+                bygroups(Keyword, Whitespace),
+                "varname",  # TODO varname the right fit?
+            ),
+            (r"(var)(\s+)", bygroups(Keyword, Whitespace), "varname"),
+            (r"(def)(\s+)", bygroups(Keyword, Whitespace), "funcname"),
+            (r"(fn)(\s+)", bygroups(Keyword, Whitespace), "funcname"),
+            (
+                r"(class)(\s+)",
+                bygroups(Keyword, Whitespace),
+                "classname",
+            ),  # not implemented yet
+            (r"(struct)(\s+)", bygroups(Keyword, Whitespace), "structname"),
+            (r"(trait)(\s+)", bygroups(Keyword, Whitespace), "structname"),
+            (r"(from)(\s+)", bygroups(Keyword.Namespace, Whitespace), "fromimport"),
+            (r"(import)(\s+)", bygroups(Keyword.Namespace, Whitespace), "import"),
+            include("expr"),
+        ],
+        "expr": [
+            # raw f-strings
+            (
+                '(?i)(rf|fr)(""")',
+                bygroups(String.Affix, String.Double),
+                combined("rfstringescape", "tdqf"),
+            ),
+            (
+                "(?i)(rf|fr)(''')",
+                bygroups(String.Affix, String.Single),
+                combined("rfstringescape", "tsqf"),
+            ),
+            (
+                '(?i)(rf|fr)(")',
+                bygroups(String.Affix, String.Double),
+                combined("rfstringescape", "dqf"),
+            ),
+            (
+                "(?i)(rf|fr)(')",
+                bygroups(String.Affix, String.Single),
+                combined("rfstringescape", "sqf"),
+            ),
+            # non-raw f-strings
+            (
+                '([fF])(""")',
+                bygroups(String.Affix, String.Double),
+                combined("fstringescape", "tdqf"),
+            ),
+            (
+                "([fF])(''')",
+                bygroups(String.Affix, String.Single),
+                combined("fstringescape", "tsqf"),
+            ),
+            (
+                '([fF])(")',
+                bygroups(String.Affix, String.Double),
+                combined("fstringescape", "dqf"),
+            ),
+            (
+                "([fF])(')",
+                bygroups(String.Affix, String.Single),
+                combined("fstringescape", "sqf"),
+            ),
+            # raw bytes and strings
+            ('(?i)(rb|br|r)(""")', bygroups(String.Affix, String.Double), "tdqs"),
+            ("(?i)(rb|br|r)(''')", bygroups(String.Affix, String.Single), "tsqs"),
+            ('(?i)(rb|br|r)(")', bygroups(String.Affix, String.Double), "dqs"),
+            ("(?i)(rb|br|r)(')", bygroups(String.Affix, String.Single), "sqs"),
+            # non-raw strings
+            (
+                '([uU]?)(""")',
+                bygroups(String.Affix, String.Double),
+                combined("stringescape", "tdqs"),
+            ),
+            (
+                "([uU]?)(''')",
+                bygroups(String.Affix, String.Single),
+                combined("stringescape", "tsqs"),
+            ),
+            (
+                '([uU]?)(")',
+                bygroups(String.Affix, String.Double),
+                combined("stringescape", "dqs"),
+            ),
+            (
+                "([uU]?)(')",
+                bygroups(String.Affix, String.Single),
+                combined("stringescape", "sqs"),
+            ),
+            # non-raw bytes
+            (
+                '([bB])(""")',
+                bygroups(String.Affix, String.Double),
+                combined("bytesescape", "tdqs"),
+            ),
+            (
+                "([bB])(''')",
+                bygroups(String.Affix, String.Single),
+                combined("bytesescape", "tsqs"),
+            ),
+            (
+                '([bB])(")',
+                bygroups(String.Affix, String.Double),
+                combined("bytesescape", "dqs"),
+            ),
+            (
+                "([bB])(')",
+                bygroups(String.Affix, String.Single),
+                combined("bytesescape", "sqs"),
+            ),
+            (r"[^\S\n]+", Text),
+            include("numbers"),
+            (r"!=|==|<<|>>|:=|[-~+/*%=<>&^|.]", Operator),
+            (r"([]{}:\(\),;[])+", Punctuation),
+            (r"(in|is|and|or|not)\b", Operator.Word),
+            include("expr-keywords"),
+            include("builtins"),
+            include("magicfuncs"),
+            include("magicvars"),
+            include("name"),
+        ],
+        "expr-inside-fstring": [
+            (r"[{([]", Punctuation, "expr-inside-fstring-inner"),
+            # without format specifier
+            (
+                r"(=\s*)?"  # debug (https://bugs.python.org/issue36817)
+                r"(\![sraf])?"  # conversion
+                r"\}",
+                String.Interpol,
+                "#pop",
+            ),
+            # with format specifier
+            # we'll catch the remaining '}' in the outer scope
+            (
+                r"(=\s*)?"  # debug (https://bugs.python.org/issue36817)
+                r"(\![sraf])?"  # conversion
+                r":",
+                String.Interpol,
+                "#pop",
+            ),
+            (r"\s+", Whitespace),  # allow new lines
+            include("expr"),
+        ],
+        "expr-inside-fstring-inner": [
+            (r"[{([]", Punctuation, "expr-inside-fstring-inner"),
+            (r"[])}]", Punctuation, "#pop"),
+            (r"\s+", Whitespace),  # allow new lines
+            include("expr"),
+        ],
+        "expr-keywords": [
+            # Based on https://docs.python.org/3/reference/expressions.html
+            (
+                words(
+                    (
+                        "async for",  # TODO https://docs.modular.com/mojo/roadmap#no-async-for-or-async-with
+                        "async with",  # TODO https://docs.modular.com/mojo/roadmap#no-async-for-or-async-with
+                        "await",
+                        "else",
+                        "for",
+                        "if",
+                        "lambda",
+                        "yield",
+                        "yield from",
+                    ),
+                    suffix=r"\b",
+                ),
+                Keyword,
+            ),
+            (words(("True", "False", "None"), suffix=r"\b"), Keyword.Constant),
+        ],
+        "keywords": [
+            (
+                words(
+                    (
+                        "assert",
+                        "async",
+                        "await",
+                        "borrowed",
+                        "break",
+                        "continue",
+                        "del",
+                        "elif",
+                        "else",
+                        "except",
+                        "finally",
+                        "for",
+                        "global",
+                        "if",
+                        "lambda",
+                        "pass",
+                        "raise",
+                        "nonlocal",
+                        "return",
+                        "try",
+                        "while",
+                        "yield",
+                        "yield from",
+                        "as",
+                        "with",
+                    ),
+                    suffix=r"\b",
+                ),
+                Keyword,
+            ),
+            (words(("True", "False", "None"), suffix=r"\b"), Keyword.Constant),
+        ],
+        "soft-keywords": [
+            # `match`, `case` and `_` soft keywords
+            (
+                r"(^[ \t]*)"  # at beginning of line + possible indentation
+                r"(match|case)\b"  # a possible keyword
+                r"(?![ \t]*(?:"  # not followed by...
+                r"[:,;=^&|@~)\]}]|(?:" +  # characters and keywords that mean this isn't
+                # pattern matching (but None/True/False is ok)
+                r"|".join(k for k in keyword.kwlist if k[0].islower())
+                + r")\b))",
+                bygroups(Whitespace, Keyword),
+                "soft-keywords-inner",
+            ),
+        ],
+        "soft-keywords-inner": [
+            # optional `_` keyword
+            (r"(\s+)([^\n_]*)(_\b)", bygroups(Whitespace, using(this), Keyword)),
+            default("#pop"),
+        ],
+        "builtins": [
+            (
+                words(
+                    (
+                        "__import__",
+                        "abs",
+                        "aiter",
+                        "all",
+                        "any",
+                        "bin",
+                        "bool",
+                        "bytearray",
+                        "breakpoint",
+                        "bytes",
+                        "callable",
+                        "chr",
+                        "classmethod",
+                        "compile",
+                        "complex",
+                        "delattr",
+                        "dict",
+                        "dir",
+                        "divmod",
+                        "enumerate",
+                        "eval",
+                        "filter",
+                        "float",
+                        "format",
+                        "frozenset",
+                        "getattr",
+                        "globals",
+                        "hasattr",
+                        "hash",
+                        "hex",
+                        "id",
+                        "input",
+                        "int",
+                        "isinstance",
+                        "issubclass",
+                        "iter",
+                        "len",
+                        "list",
+                        "locals",
+                        "map",
+                        "max",
+                        "memoryview",
+                        "min",
+                        "next",
+                        "object",
+                        "oct",
+                        "open",
+                        "ord",
+                        "pow",
+                        "print",
+                        "property",
+                        "range",
+                        "repr",
+                        "reversed",
+                        "round",
+                        "set",
+                        "setattr",
+                        "slice",
+                        "sorted",
+                        "staticmethod",
+                        "str",
+                        "sum",
+                        "super",
+                        "tuple",
+                        "type",
+                        "vars",
+                        "zip",
+                        # Mojo builtin types: https://docs.modular.com/mojo/stdlib/builtin/
+                        "AnyType",
+                        "Coroutine",
+                        "DType",
+                        "Error",
+                        "Int",
+                        "List",
+                        "ListLiteral",
+                        "Scalar",
+                        "Int8",
+                        "UInt8",
+                        "Int16",
+                        "UInt16",
+                        "Int32",
+                        "UInt32",
+                        "Int64",
+                        "UInt64",
+                        "BFloat16",
+                        "Float16",
+                        "Float32",
+                        "Float64",
+                        "SIMD",
+                        "String",
+                        "Tensor",
+                        "Tuple",
+                        "Movable",
+                        "Copyable",
+                        "CollectionElement",
+                    ),
+                    prefix=r"(?>',
+    # Binary augmented
+    '+=', '-=', '*=', '/=', '%=', '**=', '&=', '|=', '^=', '<<=', '>>=',
+    # Comparison
+    '==', '!=', '<', '<=', '>', '>=', '<=>',
+    # Patterns and assignment
+    ':=', '?', '=~', '!~', '=>',
+    # Calls and sends
+    '.', '<-', '->',
+]
+_escape_pattern = (
+    r'(?:\\x[0-9a-fA-F]{2}|\\u[0-9a-fA-F]{4}|\\U[0-9a-fA-F]{8}|'
+    r'\\["\'\\bftnr])')
+# _char = _escape_chars + [('.', String.Char)]
+_identifier = r'[_a-zA-Z]\w*'
+
+_constants = [
+    # Void constants
+    'null',
+    # Bool constants
+    'false', 'true',
+    # Double constants
+    'Infinity', 'NaN',
+    # Special objects
+    'M', 'Ref', 'throw', 'traceln',
+]
+
+_guards = [
+    'Any', 'Binding', 'Bool', 'Bytes', 'Char', 'DeepFrozen', 'Double',
+    'Empty', 'Int', 'List', 'Map', 'Near', 'NullOk', 'Same', 'Selfless',
+    'Set', 'Str', 'SubrangeGuard', 'Transparent', 'Void',
+]
+
+_safeScope = [
+    '_accumulateList', '_accumulateMap', '_auditedBy', '_bind',
+    '_booleanFlow', '_comparer', '_equalizer', '_iterForever', '_loop',
+    '_makeBytes', '_makeDouble', '_makeFinalSlot', '_makeInt', '_makeList',
+    '_makeMap', '_makeMessageDesc', '_makeOrderedSpace', '_makeParamDesc',
+    '_makeProtocolDesc', '_makeSourceSpan', '_makeString', '_makeVarSlot',
+    '_makeVerbFacet', '_mapExtract', '_matchSame', '_quasiMatcher',
+    '_slotToBinding', '_splitList', '_suchThat', '_switchFailed',
+    '_validateFor', 'b__quasiParser', 'eval', 'import', 'm__quasiParser',
+    'makeBrandPair', 'makeLazySlot', 'safeScope', 'simple__quasiParser',
+]
+
+
+class MonteLexer(RegexLexer):
+    """
+    Lexer for the Monte programming language.
+    """
+    name = 'Monte'
+    url = 'https://monte.readthedocs.io/'
+    aliases = ['monte']
+    filenames = ['*.mt']
+    version_added = '2.2'
+
+    tokens = {
+        'root': [
+            # Comments
+            (r'#[^\n]*\n', Comment),
+
+            # Docstrings
+            # Apologies for the non-greedy matcher here.
+            (r'/\*\*.*?\*/', String.Doc),
+
+            # `var` declarations
+            (r'\bvar\b', Keyword.Declaration, 'var'),
+
+            # `interface` declarations
+            (r'\binterface\b', Keyword.Declaration, 'interface'),
+
+            # method declarations
+            (words(_methods, prefix='\\b', suffix='\\b'),
+             Keyword, 'method'),
+
+            # All other declarations
+            (words(_declarations, prefix='\\b', suffix='\\b'),
+             Keyword.Declaration),
+
+            # Keywords
+            (words(_keywords, prefix='\\b', suffix='\\b'), Keyword),
+
+            # Literals
+            ('[+-]?0x[_0-9a-fA-F]+', Number.Hex),
+            (r'[+-]?[_0-9]+\.[_0-9]*([eE][+-]?[_0-9]+)?', Number.Float),
+            ('[+-]?[_0-9]+', Number.Integer),
+            ("'", String.Double, 'char'),
+            ('"', String.Double, 'string'),
+
+            # Quasiliterals
+            ('`', String.Backtick, 'ql'),
+
+            # Operators
+            (words(_operators), Operator),
+
+            # Verb operators
+            (_identifier + '=', Operator.Word),
+
+            # Safe scope constants
+            (words(_constants, prefix='\\b', suffix='\\b'),
+             Keyword.Pseudo),
+
+            # Safe scope guards
+            (words(_guards, prefix='\\b', suffix='\\b'), Keyword.Type),
+
+            # All other safe scope names
+            (words(_safeScope, prefix='\\b', suffix='\\b'),
+             Name.Builtin),
+
+            # Identifiers
+            (_identifier, Name),
+
+            # Punctuation
+            (r'\(|\)|\{|\}|\[|\]|:|,', Punctuation),
+
+            # Whitespace
+            (' +', Whitespace),
+
+            # Definite lexer errors
+            ('=', Error),
+        ],
+        'char': [
+            # It is definitely an error to have a char of width == 0.
+            ("'", Error, 'root'),
+            (_escape_pattern, String.Escape, 'charEnd'),
+            ('.', String.Char, 'charEnd'),
+        ],
+        'charEnd': [
+            ("'", String.Char, '#pop:2'),
+            # It is definitely an error to have a char of width > 1.
+            ('.', Error),
+        ],
+        # The state of things coming into an interface.
+        'interface': [
+            (' +', Whitespace),
+            (_identifier, Name.Class, '#pop'),
+            include('root'),
+        ],
+        # The state of things coming into a method.
+        'method': [
+            (' +', Whitespace),
+            (_identifier, Name.Function, '#pop'),
+            include('root'),
+        ],
+        'string': [
+            ('"', String.Double, 'root'),
+            (_escape_pattern, String.Escape),
+            (r'\n', String.Double),
+            ('.', String.Double),
+        ],
+        'ql': [
+            ('`', String.Backtick, 'root'),
+            (r'\$' + _escape_pattern, String.Escape),
+            (r'\$\$', String.Escape),
+            (r'@@', String.Escape),
+            (r'\$\{', String.Interpol, 'qlNest'),
+            (r'@\{', String.Interpol, 'qlNest'),
+            (r'\$' + _identifier, Name),
+            ('@' + _identifier, Name),
+            ('.', String.Backtick),
+        ],
+        'qlNest': [
+            (r'\}', String.Interpol, '#pop'),
+            include('root'),
+        ],
+        # The state of things immediately following `var`.
+        'var': [
+            (' +', Whitespace),
+            (_identifier, Name.Variable, '#pop'),
+            include('root'),
+        ],
+    }
diff --git a/pygments/lexers/mosel.py b/pygments/lexers/mosel.py
new file mode 100644
index 0000000000000000000000000000000000000000..426c9a145261af10f25c2f6e5f16fdccf665b065
--- /dev/null
+++ b/pygments/lexers/mosel.py
@@ -0,0 +1,447 @@
+"""
+    pygments.lexers.mosel
+    ~~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for the mosel language.
+    http://www.fico.com/en/products/fico-xpress-optimization
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, words
+from pygments.token import Text, Comment, Operator, Keyword, Name, String, \
+    Number, Punctuation
+
+__all__ = ['MoselLexer']
+
+FUNCTIONS = (
+    # core functions
+    '_',
+    'abs',
+    'arctan',
+    'asproc',
+    'assert',
+    'bitflip',
+    'bitneg',
+    'bitset',
+    'bitshift',
+    'bittest',
+    'bitval',
+    'ceil',
+    'cos',
+    'create',
+    'currentdate',
+    'currenttime',
+    'cutelt',
+    'cutfirst',
+    'cuthead',
+    'cutlast',
+    'cuttail',
+    'datablock',
+    'delcell',
+    'exists',
+    'exit',
+    'exp',
+    'exportprob',
+    'fclose',
+    'fflush',
+    'finalize',
+    'findfirst',
+    'findlast',
+    'floor',
+    'fopen',
+    'fselect',
+    'fskipline',
+    'fwrite',
+    'fwrite_',
+    'fwriteln',
+    'fwriteln_',
+    'getact',
+    'getcoeff',
+    'getcoeffs',
+    'getdual',
+    'getelt',
+    'getfid',
+    'getfirst',
+    'getfname',
+    'gethead',
+    'getlast',
+    'getobjval',
+    'getparam',
+    'getrcost',
+    'getreadcnt',
+    'getreverse',
+    'getsize',
+    'getslack',
+    'getsol',
+    'gettail',
+    'gettype',
+    'getvars',
+    'isdynamic',
+    'iseof',
+    'isfinite',
+    'ishidden',
+    'isinf',
+    'isnan',
+    'isodd',
+    'ln',
+    'localsetparam',
+    'log',
+    'makesos1',
+    'makesos2',
+    'maxlist',
+    'memoryuse',
+    'minlist',
+    'newmuid',
+    'publish',
+    'random',
+    'read',
+    'readln',
+    'reset',
+    'restoreparam',
+    'reverse',
+    'round',
+    'setcoeff',
+    'sethidden',
+    'setioerr',
+    'setmatherr',
+    'setname',
+    'setparam',
+    'setrandseed',
+    'setrange',
+    'settype',
+    'sin',
+    'splithead',
+    'splittail',
+    'sqrt',
+    'strfmt',
+    'substr',
+    'timestamp',
+    'unpublish',
+    'versionnum',
+    'versionstr',
+    'write',
+    'write_',
+    'writeln',
+    'writeln_',
+
+    # mosel exam mmxprs | sed -n -e "s/ [pf][a-z]* \([a-zA-Z0-9_]*\).*/'\1',/p" | sort -u
+    'addcut',
+    'addcuts',
+    'addmipsol',
+    'basisstability',
+    'calcsolinfo',
+    'clearmipdir',
+    'clearmodcut',
+    'command',
+    'copysoltoinit',
+    'crossoverlpsol',
+    'defdelayedrows',
+    'defsecurevecs',
+    'delcuts',
+    'dropcuts',
+    'estimatemarginals',
+    'fixglobal',
+    'flushmsgq',
+    'getbstat',
+    'getcnlist',
+    'getcplist',
+    'getdualray',
+    'getiis',
+    'getiissense',
+    'getiistype',
+    'getinfcause',
+    'getinfeas',
+    'getlb',
+    'getlct',
+    'getleft',
+    'getloadedlinctrs',
+    'getloadedmpvars',
+    'getname',
+    'getprimalray',
+    'getprobstat',
+    'getrange',
+    'getright',
+    'getsensrng',
+    'getsize',
+    'getsol',
+    'gettype',
+    'getub',
+    'getvars',
+    'gety',
+    'hasfeature',
+    'implies',
+    'indicator',
+    'initglobal',
+    'ishidden',
+    'isiisvalid',
+    'isintegral',
+    'loadbasis',
+    'loadcuts',
+    'loadlpsol',
+    'loadmipsol',
+    'loadprob',
+    'maximise',
+    'maximize',
+    'minimise',
+    'minimize',
+    'postsolve',
+    'readbasis',
+    'readdirs',
+    'readsol',
+    'refinemipsol',
+    'rejectintsol',
+    'repairinfeas',
+    'repairinfeas_deprec',
+    'resetbasis',
+    'resetiis',
+    'resetsol',
+    'savebasis',
+    'savemipsol',
+    'savesol',
+    'savestate',
+    'selectsol',
+    'setarchconsistency',
+    'setbstat',
+    'setcallback',
+    'setcbcutoff',
+    'setgndata',
+    'sethidden',
+    'setlb',
+    'setmipdir',
+    'setmodcut',
+    'setsol',
+    'setub',
+    'setucbdata',
+    'stopoptimise',
+    'stopoptimize',
+    'storecut',
+    'storecuts',
+    'unloadprob',
+    'uselastbarsol',
+    'writebasis',
+    'writedirs',
+    'writeprob',
+    'writesol',
+    'xor',
+    'xprs_addctr',
+    'xprs_addindic',
+
+    # mosel exam mmsystem | sed -n -e "s/ [pf][a-z]* \([a-zA-Z0-9_]*\).*/'\1',/p" | sort -u
+    'addmonths',
+    'copytext',
+    'cuttext',
+    'deltext',
+    'endswith',
+    'erase',
+    'expandpath',
+    'fcopy',
+    'fdelete',
+    'findfiles',
+    'findtext',
+    'fmove',
+    'formattext',
+    'getasnumber',
+    'getchar',
+    'getcwd',
+    'getdate',
+    'getday',
+    'getdaynum',
+    'getdays',
+    'getdirsep',
+    'getdsoparam',
+    'getendparse',
+    'getenv',
+    'getfsize',
+    'getfstat',
+    'getftime',
+    'gethour',
+    'getminute',
+    'getmonth',
+    'getmsec',
+    'getoserrmsg',
+    'getoserror',
+    'getpathsep',
+    'getqtype',
+    'getsecond',
+    'getsepchar',
+    'getsize',
+    'getstart',
+    'getsucc',
+    'getsysinfo',
+    'getsysstat',
+    'gettime',
+    'gettmpdir',
+    'gettrim',
+    'getweekday',
+    'getyear',
+    'inserttext',
+    'isvalid',
+    'jointext',
+    'makedir',
+    'makepath',
+    'newtar',
+    'newzip',
+    'nextfield',
+    'openpipe',
+    'parseextn',
+    'parseint',
+    'parsereal',
+    'parsetext',
+    'pastetext',
+    'pathmatch',
+    'pathsplit',
+    'qsort',
+    'quote',
+    'readtextline',
+    'regmatch',
+    'regreplace',
+    'removedir',
+    'removefiles',
+    'setchar',
+    'setdate',
+    'setday',
+    'setdsoparam',
+    'setendparse',
+    'setenv',
+    'sethour',
+    'setminute',
+    'setmonth',
+    'setmsec',
+    'setoserror',
+    'setqtype',
+    'setsecond',
+    'setsepchar',
+    'setstart',
+    'setsucc',
+    'settime',
+    'settrim',
+    'setyear',
+    'sleep',
+    'splittext',
+    'startswith',
+    'system',
+    'tarlist',
+    'textfmt',
+    'tolower',
+    'toupper',
+    'trim',
+    'untar',
+    'unzip',
+    'ziplist',
+
+    # mosel exam mmjobs | sed -n -e "s/ [pf][a-z]* \([a-zA-Z0-9_]*\).*/'\1',/p" | sort -u
+    'canceltimer',
+    'clearaliases',
+    'compile',
+    'connect',
+    'detach',
+    'disconnect',
+    'dropnextevent',
+    'findxsrvs',
+    'getaliases',
+    'getannidents',
+    'getannotations',
+    'getbanner',
+    'getclass',
+    'getdsoprop',
+    'getdsopropnum',
+    'getexitcode',
+    'getfromgid',
+    'getfromid',
+    'getfromuid',
+    'getgid',
+    'gethostalias',
+    'getid',
+    'getmodprop',
+    'getmodpropnum',
+    'getnextevent',
+    'getnode',
+    'getrmtid',
+    'getstatus',
+    'getsysinfo',
+    'gettimer',
+    'getuid',
+    'getvalue',
+    'isqueueempty',
+    'load',
+    'nullevent',
+    'peeknextevent',
+    'resetmodpar',
+    'run',
+    'send',
+    'setcontrol',
+    'setdefstream',
+    'setgid',
+    'sethostalias',
+    'setmodpar',
+    'settimer',
+    'setuid',
+    'setworkdir',
+    'stop',
+    'unload',
+    'wait',
+    'waitexpired',
+    'waitfor',
+    'waitforend',
+)
+
+
+class MoselLexer(RegexLexer):
+    """
+    For the Mosel optimization language.
+    """
+    name = 'Mosel'
+    aliases = ['mosel']
+    filenames = ['*.mos']
+    url = 'https://www.fico.com/fico-xpress-optimization/docs/latest/mosel/mosel_lang/dhtml/moselreflang.html'
+    version_added = '2.6'
+
+    tokens = {
+        'root': [
+            (r'\n', Text),
+            (r'\s+', Text.Whitespace),
+            (r'!.*?\n', Comment.Single),
+            (r'\(!(.|\n)*?!\)', Comment.Multiline),
+            (words((
+                'and', 'as', 'break', 'case', 'count', 'declarations', 'do',
+                'dynamic', 'elif', 'else', 'end-', 'end', 'evaluation', 'false',
+                'forall', 'forward', 'from', 'function', 'hashmap', 'if',
+                'imports', 'include', 'initialisations', 'initializations', 'inter',
+                'max', 'min', 'model', 'namespace', 'next', 'not', 'nsgroup',
+                'nssearch', 'of', 'options', 'or', 'package', 'parameters',
+                'procedure', 'public', 'prod', 'record', 'repeat', 'requirements',
+                'return', 'sum', 'then', 'to', 'true', 'union', 'until', 'uses',
+                'version', 'while', 'with'), prefix=r'\b', suffix=r'\b'),
+             Keyword.Builtin),
+            (words((
+                'range', 'array', 'set', 'list', 'mpvar', 'mpproblem', 'linctr',
+                'nlctr', 'integer', 'string', 'real', 'boolean', 'text', 'time',
+                'date', 'datetime', 'returned', 'Model', 'Mosel', 'counter',
+                'xmldoc', 'is_sos1', 'is_sos2', 'is_integer', 'is_binary',
+                'is_continuous', 'is_free', 'is_semcont', 'is_semint',
+                'is_partint'), prefix=r'\b', suffix=r'\b'),
+             Keyword.Type),
+            (r'(\+|\-|\*|/|=|<=|>=|\||\^|<|>|<>|\.\.|\.|:=|::|:|in|mod|div)',
+             Operator),
+            (r'[()\[\]{},;]+', Punctuation),
+            (words(FUNCTIONS,  prefix=r'\b', suffix=r'\b'), Name.Function),
+            (r'(\d+\.(?!\.)\d*|\.(?!.)\d+)([eE][+-]?\d+)?', Number.Float),
+            (r'\d+([eE][+-]?\d+)?', Number.Integer),
+            (r'[+-]?Infinity', Number.Integer),
+            (r'0[xX][0-9a-fA-F]+', Number),
+            (r'"', String.Double, 'double_quote'),
+            (r'\'', String.Single, 'single_quote'),
+            (r'(\w+|(\.(?!\.)))', Text),
+        ],
+        'single_quote': [
+            (r'\'', String.Single, '#pop'),
+            (r'[^\']+', String.Single),
+        ],
+        'double_quote': [
+            (r'(\\"|\\[0-7]{1,3}\D|\\[abfnrtv]|\\\\)', String.Escape),
+            (r'\"', String.Double, '#pop'),
+            (r'[^"\\]+', String.Double),
+        ],
+    }
diff --git a/pygments/lexers/ncl.py b/pygments/lexers/ncl.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2f476087bf6e6e92328542c52d0c609d1972845
--- /dev/null
+++ b/pygments/lexers/ncl.py
@@ -0,0 +1,894 @@
+"""
+    pygments.lexers.ncl
+    ~~~~~~~~~~~~~~~~~~~
+
+    Lexers for NCAR Command Language.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import re
+
+from pygments.lexer import RegexLexer, include, words
+from pygments.token import Text, Comment, Operator, Keyword, Name, String, \
+    Number, Punctuation
+
+__all__ = ['NCLLexer']
+
+
+class NCLLexer(RegexLexer):
+    """
+    Lexer for NCL code.
+    """
+    name = 'NCL'
+    aliases = ['ncl']
+    filenames = ['*.ncl']
+    mimetypes = ['text/ncl']
+    url = 'https://www.ncl.ucar.edu'
+    version_added = '2.2'
+
+    flags = re.MULTILINE
+
+    tokens = {
+        'root': [
+            (r';.*\n', Comment),
+            include('strings'),
+            include('core'),
+            (r'[a-zA-Z_]\w*', Name),
+            include('nums'),
+            (r'[\s]+', Text),
+        ],
+        'core': [
+            # Statements
+            (words((
+                'begin', 'break', 'continue', 'create', 'defaultapp', 'do',
+                'else', 'end', 'external', 'exit', 'True', 'False', 'file', 'function',
+                'getvalues', 'graphic', 'group', 'if', 'list', 'load', 'local',
+                'new', '_Missing', 'Missing', 'noparent', 'procedure',
+                'quit', 'QUIT', 'Quit', 'record', 'return', 'setvalues', 'stop',
+                'then', 'while'), prefix=r'\b', suffix=r'\s*\b'),
+             Keyword),
+
+            # Data Types
+            (words((
+                'ubyte', 'uint', 'uint64', 'ulong', 'string', 'byte',
+                'character', 'double', 'float', 'integer', 'int64', 'logical',
+                'long', 'short', 'ushort', 'enumeric', 'numeric', 'snumeric'),
+                prefix=r'\b', suffix=r'\s*\b'),
+             Keyword.Type),
+
+            # Operators
+            (r'[\%^*+\-/<>]', Operator),
+
+            # punctuation:
+            (r'[\[\]():@$!&|.,\\{}]', Punctuation),
+            (r'[=:]', Punctuation),
+
+            # Intrinsics
+            (words((
+                'abs', 'acos', 'addfile', 'addfiles', 'all', 'angmom_atm', 'any',
+                'area_conserve_remap', 'area_hi2lores', 'area_poly_sphere',
+                'asciiread', 'asciiwrite', 'asin', 'atan', 'atan2', 'attsetvalues',
+                'avg', 'betainc', 'bin_avg', 'bin_sum', 'bw_bandpass_filter',
+                'cancor', 'cbinread', 'cbinwrite', 'cd_calendar', 'cd_inv_calendar',
+                'cdfbin_p', 'cdfbin_pr', 'cdfbin_s', 'cdfbin_xn', 'cdfchi_p',
+                'cdfchi_x', 'cdfgam_p', 'cdfgam_x', 'cdfnor_p', 'cdfnor_x',
+                'cdft_p', 'cdft_t', 'ceil', 'center_finite_diff',
+                'center_finite_diff_n', 'cfftb', 'cfftf', 'cfftf_frq_reorder',
+                'charactertodouble', 'charactertofloat', 'charactertointeger',
+                'charactertolong', 'charactertoshort', 'charactertostring',
+                'chartodouble', 'chartofloat', 'chartoint', 'chartointeger',
+                'chartolong', 'chartoshort', 'chartostring', 'chiinv', 'clear',
+                'color_index_to_rgba', 'conform', 'conform_dims', 'cos', 'cosh',
+                'count_unique_values', 'covcorm', 'covcorm_xy', 'craybinnumrec',
+                'craybinrecread', 'create_graphic', 'csa1', 'csa1d', 'csa1s',
+                'csa1x', 'csa1xd', 'csa1xs', 'csa2', 'csa2d', 'csa2l', 'csa2ld',
+                'csa2ls', 'csa2lx', 'csa2lxd', 'csa2lxs', 'csa2s', 'csa2x',
+                'csa2xd', 'csa2xs', 'csa3', 'csa3d', 'csa3l', 'csa3ld', 'csa3ls',
+                'csa3lx', 'csa3lxd', 'csa3lxs', 'csa3s', 'csa3x', 'csa3xd',
+                'csa3xs', 'csc2s', 'csgetp', 'css2c', 'cssetp', 'cssgrid', 'csstri',
+                'csvoro', 'cumsum', 'cz2ccm', 'datatondc', 'day_of_week',
+                'day_of_year', 'days_in_month', 'default_fillvalue', 'delete',
+                'depth_to_pres', 'destroy', 'determinant', 'dewtemp_trh',
+                'dgeevx_lapack', 'dim_acumrun_n', 'dim_avg', 'dim_avg_n',
+                'dim_avg_wgt', 'dim_avg_wgt_n', 'dim_cumsum', 'dim_cumsum_n',
+                'dim_gamfit_n', 'dim_gbits', 'dim_max', 'dim_max_n', 'dim_median',
+                'dim_median_n', 'dim_min', 'dim_min_n', 'dim_num', 'dim_num_n',
+                'dim_numrun_n', 'dim_pqsort', 'dim_pqsort_n', 'dim_product',
+                'dim_product_n', 'dim_rmsd', 'dim_rmsd_n', 'dim_rmvmean',
+                'dim_rmvmean_n', 'dim_rmvmed', 'dim_rmvmed_n', 'dim_spi_n',
+                'dim_standardize', 'dim_standardize_n', 'dim_stat4', 'dim_stat4_n',
+                'dim_stddev', 'dim_stddev_n', 'dim_sum', 'dim_sum_n', 'dim_sum_wgt',
+                'dim_sum_wgt_n', 'dim_variance', 'dim_variance_n', 'dimsizes',
+                'doubletobyte', 'doubletochar', 'doubletocharacter',
+                'doubletofloat', 'doubletoint', 'doubletointeger', 'doubletolong',
+                'doubletoshort', 'dpres_hybrid_ccm', 'dpres_plevel', 'draw',
+                'draw_color_palette', 'dsgetp', 'dsgrid2', 'dsgrid2d', 'dsgrid2s',
+                'dsgrid3', 'dsgrid3d', 'dsgrid3s', 'dspnt2', 'dspnt2d', 'dspnt2s',
+                'dspnt3', 'dspnt3d', 'dspnt3s', 'dssetp', 'dtrend', 'dtrend_msg',
+                'dtrend_msg_n', 'dtrend_n', 'dtrend_quadratic',
+                'dtrend_quadratic_msg_n', 'dv2uvf', 'dv2uvg', 'dz_height',
+                'echo_off', 'echo_on', 'eof2data', 'eof_varimax', 'eofcor',
+                'eofcor_pcmsg', 'eofcor_ts', 'eofcov', 'eofcov_pcmsg', 'eofcov_ts',
+                'eofunc', 'eofunc_ts', 'eofunc_varimax', 'equiv_sample_size', 'erf',
+                'erfc', 'esacr', 'esacv', 'esccr', 'esccv', 'escorc', 'escorc_n',
+                'escovc', 'exit', 'exp', 'exp_tapersh', 'exp_tapersh_wgts',
+                'exp_tapershC', 'ezfftb', 'ezfftb_n', 'ezfftf', 'ezfftf_n',
+                'f2fosh', 'f2foshv', 'f2fsh', 'f2fshv', 'f2gsh', 'f2gshv', 'fabs',
+                'fbindirread', 'fbindirwrite', 'fbinnumrec', 'fbinread',
+                'fbinrecread', 'fbinrecwrite', 'fbinwrite', 'fft2db', 'fft2df',
+                'fftshift', 'fileattdef', 'filechunkdimdef', 'filedimdef',
+                'fileexists', 'filegrpdef', 'filevarattdef', 'filevarchunkdef',
+                'filevarcompressleveldef', 'filevardef', 'filevardimsizes',
+                'filwgts_lancos', 'filwgts_lanczos', 'filwgts_normal',
+                'floattobyte', 'floattochar', 'floattocharacter', 'floattoint',
+                'floattointeger', 'floattolong', 'floattoshort', 'floor',
+                'fluxEddy', 'fo2fsh', 'fo2fshv', 'fourier_info', 'frame', 'fspan',
+                'ftcurv', 'ftcurvd', 'ftcurvi', 'ftcurvp', 'ftcurvpi', 'ftcurvps',
+                'ftcurvs', 'ftest', 'ftgetp', 'ftkurv', 'ftkurvd', 'ftkurvp',
+                'ftkurvpd', 'ftsetp', 'ftsurf', 'g2fsh', 'g2fshv', 'g2gsh',
+                'g2gshv', 'gamma', 'gammainc', 'gaus', 'gaus_lobat',
+                'gaus_lobat_wgt', 'gc_aangle', 'gc_clkwise', 'gc_dangle',
+                'gc_inout', 'gc_latlon', 'gc_onarc', 'gc_pnt2gc', 'gc_qarea',
+                'gc_tarea', 'generate_2d_array', 'get_color_index',
+                'get_color_rgba', 'get_cpu_time', 'get_isolines', 'get_ncl_version',
+                'get_script_name', 'get_script_prefix_name', 'get_sphere_radius',
+                'get_unique_values', 'getbitsone', 'getenv', 'getfiledimsizes',
+                'getfilegrpnames', 'getfilepath', 'getfilevaratts',
+                'getfilevarchunkdimsizes', 'getfilevardims', 'getfilevardimsizes',
+                'getfilevarnames', 'getfilevartypes', 'getvaratts', 'getvardims',
+                'gradsf', 'gradsg', 'greg2jul', 'grid2triple', 'hlsrgb', 'hsvrgb',
+                'hydro', 'hyi2hyo', 'idsfft', 'igradsf', 'igradsg', 'ilapsf',
+                'ilapsg', 'ilapvf', 'ilapvg', 'ind', 'ind_resolve', 'int2p',
+                'int2p_n', 'integertobyte', 'integertochar', 'integertocharacter',
+                'integertoshort', 'inttobyte', 'inttochar', 'inttoshort',
+                'inverse_matrix', 'isatt', 'isbigendian', 'isbyte', 'ischar',
+                'iscoord', 'isdefined', 'isdim', 'isdimnamed', 'isdouble',
+                'isenumeric', 'isfile', 'isfilepresent', 'isfilevar',
+                'isfilevaratt', 'isfilevarcoord', 'isfilevardim', 'isfloat',
+                'isfunc', 'isgraphic', 'isint', 'isint64', 'isinteger',
+                'isleapyear', 'islogical', 'islong', 'ismissing', 'isnan_ieee',
+                'isnumeric', 'ispan', 'isproc', 'isshort', 'issnumeric', 'isstring',
+                'isubyte', 'isuint', 'isuint64', 'isulong', 'isunlimited',
+                'isunsigned', 'isushort', 'isvar', 'jul2greg', 'kmeans_as136',
+                'kolsm2_n', 'kron_product', 'lapsf', 'lapsg', 'lapvf', 'lapvg',
+                'latlon2utm', 'lclvl', 'lderuvf', 'lderuvg', 'linint1', 'linint1_n',
+                'linint2', 'linint2_points', 'linmsg', 'linmsg_n', 'linrood_latwgt',
+                'linrood_wgt', 'list_files', 'list_filevars', 'list_hlus',
+                'list_procfuncs', 'list_vars', 'ListAppend', 'ListCount',
+                'ListGetType', 'ListIndex', 'ListIndexFromName', 'ListPop',
+                'ListPush', 'ListSetType', 'loadscript', 'local_max', 'local_min',
+                'log', 'log10', 'longtobyte', 'longtochar', 'longtocharacter',
+                'longtoint', 'longtointeger', 'longtoshort', 'lspoly', 'lspoly_n',
+                'mask', 'max', 'maxind', 'min', 'minind', 'mixed_layer_depth',
+                'mixhum_ptd', 'mixhum_ptrh', 'mjo_cross_coh2pha',
+                'mjo_cross_segment', 'moc_globe_atl', 'monthday', 'natgrid',
+                'natgridd', 'natgrids', 'ncargpath', 'ncargversion', 'ndctodata',
+                'ndtooned', 'new', 'NewList', 'ngezlogo', 'nggcog', 'nggetp',
+                'nglogo', 'ngsetp', 'NhlAddAnnotation', 'NhlAddData',
+                'NhlAddOverlay', 'NhlAddPrimitive', 'NhlAppGetDefaultParentId',
+                'NhlChangeWorkstation', 'NhlClassName', 'NhlClearWorkstation',
+                'NhlDataPolygon', 'NhlDataPolyline', 'NhlDataPolymarker',
+                'NhlDataToNDC', 'NhlDestroy', 'NhlDraw', 'NhlFrame', 'NhlFreeColor',
+                'NhlGetBB', 'NhlGetClassResources', 'NhlGetErrorObjectId',
+                'NhlGetNamedColorIndex', 'NhlGetParentId',
+                'NhlGetParentWorkstation', 'NhlGetWorkspaceObjectId',
+                'NhlIsAllocatedColor', 'NhlIsApp', 'NhlIsDataComm', 'NhlIsDataItem',
+                'NhlIsDataSpec', 'NhlIsTransform', 'NhlIsView', 'NhlIsWorkstation',
+                'NhlName', 'NhlNDCPolygon', 'NhlNDCPolyline', 'NhlNDCPolymarker',
+                'NhlNDCToData', 'NhlNewColor', 'NhlNewDashPattern', 'NhlNewMarker',
+                'NhlPalGetDefined', 'NhlRemoveAnnotation', 'NhlRemoveData',
+                'NhlRemoveOverlay', 'NhlRemovePrimitive', 'NhlSetColor',
+                'NhlSetDashPattern', 'NhlSetMarker', 'NhlUpdateData',
+                'NhlUpdateWorkstation', 'nice_mnmxintvl', 'nngetaspectd',
+                'nngetaspects', 'nngetp', 'nngetsloped', 'nngetslopes', 'nngetwts',
+                'nngetwtsd', 'nnpnt', 'nnpntd', 'nnpntend', 'nnpntendd',
+                'nnpntinit', 'nnpntinitd', 'nnpntinits', 'nnpnts', 'nnsetp', 'num',
+                'obj_anal_ic', 'omega_ccm', 'onedtond', 'overlay', 'paleo_outline',
+                'pdfxy_bin', 'poisson_grid_fill', 'pop_remap', 'potmp_insitu_ocn',
+                'prcwater_dp', 'pres2hybrid', 'pres_hybrid_ccm', 'pres_sigma',
+                'print', 'print_table', 'printFileVarSummary', 'printVarSummary',
+                'product', 'pslec', 'pslhor', 'pslhyp', 'qsort', 'rand',
+                'random_chi', 'random_gamma', 'random_normal', 'random_setallseed',
+                'random_uniform', 'rcm2points', 'rcm2rgrid', 'rdsstoi',
+                'read_colormap_file', 'reg_multlin', 'regcoef', 'regCoef_n',
+                'regline', 'relhum', 'replace_ieeenan', 'reshape', 'reshape_ind',
+                'rgba_to_color_index', 'rgbhls', 'rgbhsv', 'rgbyiq', 'rgrid2rcm',
+                'rhomb_trunc', 'rip_cape_2d', 'rip_cape_3d', 'round', 'rtest',
+                'runave', 'runave_n', 'set_default_fillvalue', 'set_sphere_radius',
+                'setfileoption', 'sfvp2uvf', 'sfvp2uvg', 'shaec', 'shagc',
+                'shgetnp', 'shgetp', 'shgrid', 'shorttobyte', 'shorttochar',
+                'shorttocharacter', 'show_ascii', 'shsec', 'shsetp', 'shsgc',
+                'shsgc_R42', 'sigma2hybrid', 'simpeq', 'simpne', 'sin',
+                'sindex_yrmo', 'sinh', 'sizeof', 'sleep', 'smth9', 'snindex_yrmo',
+                'solve_linsys', 'span_color_indexes', 'span_color_rgba',
+                'sparse_matrix_mult', 'spcorr', 'spcorr_n', 'specx_anal',
+                'specxy_anal', 'spei', 'sprintf', 'sprinti', 'sqrt', 'sqsort',
+                'srand', 'stat2', 'stat4', 'stat_medrng', 'stat_trim',
+                'status_exit', 'stdatmus_p2tdz', 'stdatmus_z2tdp', 'stddev',
+                'str_capital', 'str_concat', 'str_fields_count', 'str_get_cols',
+                'str_get_dq', 'str_get_field', 'str_get_nl', 'str_get_sq',
+                'str_get_tab', 'str_index_of_substr', 'str_insert', 'str_is_blank',
+                'str_join', 'str_left_strip', 'str_lower', 'str_match',
+                'str_match_ic', 'str_match_ic_regex', 'str_match_ind',
+                'str_match_ind_ic', 'str_match_ind_ic_regex', 'str_match_ind_regex',
+                'str_match_regex', 'str_right_strip', 'str_split',
+                'str_split_by_length', 'str_split_csv', 'str_squeeze', 'str_strip',
+                'str_sub_str', 'str_switch', 'str_upper', 'stringtochar',
+                'stringtocharacter', 'stringtodouble', 'stringtofloat',
+                'stringtoint', 'stringtointeger', 'stringtolong', 'stringtoshort',
+                'strlen', 'student_t', 'sum', 'svd_lapack', 'svdcov', 'svdcov_sv',
+                'svdstd', 'svdstd_sv', 'system', 'systemfunc', 'tan', 'tanh',
+                'taper', 'taper_n', 'tdclrs', 'tdctri', 'tdcudp', 'tdcurv',
+                'tddtri', 'tdez2d', 'tdez3d', 'tdgetp', 'tdgrds', 'tdgrid',
+                'tdgtrs', 'tdinit', 'tditri', 'tdlbla', 'tdlblp', 'tdlbls',
+                'tdline', 'tdlndp', 'tdlnpa', 'tdlpdp', 'tdmtri', 'tdotri',
+                'tdpara', 'tdplch', 'tdprpa', 'tdprpi', 'tdprpt', 'tdsetp',
+                'tdsort', 'tdstri', 'tdstrs', 'tdttri', 'thornthwaite', 'tobyte',
+                'tochar', 'todouble', 'tofloat', 'toint', 'toint64', 'tointeger',
+                'tolong', 'toshort', 'tosigned', 'tostring', 'tostring_with_format',
+                'totype', 'toubyte', 'touint', 'touint64', 'toulong', 'tounsigned',
+                'toushort', 'trend_manken', 'tri_trunc', 'triple2grid',
+                'triple2grid2d', 'trop_wmo', 'ttest', 'typeof', 'undef',
+                'unique_string', 'update', 'ushorttoint', 'ut_calendar',
+                'ut_inv_calendar', 'utm2latlon', 'uv2dv_cfd', 'uv2dvf', 'uv2dvg',
+                'uv2sfvpf', 'uv2sfvpg', 'uv2vr_cfd', 'uv2vrdvf', 'uv2vrdvg',
+                'uv2vrf', 'uv2vrg', 'v5d_close', 'v5d_create', 'v5d_setLowLev',
+                'v5d_setUnits', 'v5d_write', 'v5d_write_var', 'variance', 'vhaec',
+                'vhagc', 'vhsec', 'vhsgc', 'vibeta', 'vinth2p', 'vinth2p_ecmwf',
+                'vinth2p_ecmwf_nodes', 'vinth2p_nodes', 'vintp2p_ecmwf', 'vr2uvf',
+                'vr2uvg', 'vrdv2uvf', 'vrdv2uvg', 'wavelet', 'wavelet_default',
+                'weibull', 'wgt_area_smooth', 'wgt_areaave', 'wgt_areaave2',
+                'wgt_arearmse', 'wgt_arearmse2', 'wgt_areasum2', 'wgt_runave',
+                'wgt_runave_n', 'wgt_vert_avg_beta', 'wgt_volave', 'wgt_volave_ccm',
+                'wgt_volrmse', 'wgt_volrmse_ccm', 'where', 'wk_smooth121', 'wmbarb',
+                'wmbarbmap', 'wmdrft', 'wmgetp', 'wmlabs', 'wmsetp', 'wmstnm',
+                'wmvect', 'wmvectmap', 'wmvlbl', 'wrf_avo', 'wrf_cape_2d',
+                'wrf_cape_3d', 'wrf_dbz', 'wrf_eth', 'wrf_helicity', 'wrf_ij_to_ll',
+                'wrf_interp_1d', 'wrf_interp_2d_xy', 'wrf_interp_3d_z',
+                'wrf_latlon_to_ij', 'wrf_ll_to_ij', 'wrf_omega', 'wrf_pvo',
+                'wrf_rh', 'wrf_slp', 'wrf_smooth_2d', 'wrf_td', 'wrf_tk',
+                'wrf_updraft_helicity', 'wrf_uvmet', 'wrf_virtual_temp',
+                'wrf_wetbulb', 'wrf_wps_close_int', 'wrf_wps_open_int',
+                'wrf_wps_rddata_int', 'wrf_wps_rdhead_int', 'wrf_wps_read_int',
+                'wrf_wps_write_int', 'write_matrix', 'write_table', 'yiqrgb',
+                'z2geouv', 'zonal_mpsi', 'addfiles_GetVar', 'advect_variable',
+                'area_conserve_remap_Wrap', 'area_hi2lores_Wrap',
+                'array_append_record', 'assignFillValue', 'byte2flt',
+                'byte2flt_hdf', 'calcDayAnomTLL', 'calcMonAnomLLLT',
+                'calcMonAnomLLT', 'calcMonAnomTLL', 'calcMonAnomTLLL',
+                'calculate_monthly_values', 'cd_convert', 'changeCase',
+                'changeCaseChar', 'clmDayTLL', 'clmDayTLLL', 'clmMon2clmDay',
+                'clmMonLLLT', 'clmMonLLT', 'clmMonTLL', 'clmMonTLLL', 'closest_val',
+                'copy_VarAtts', 'copy_VarCoords', 'copy_VarCoords_1',
+                'copy_VarCoords_2', 'copy_VarMeta', 'copyatt', 'crossp3',
+                'cshstringtolist', 'cssgrid_Wrap', 'dble2flt', 'decimalPlaces',
+                'delete_VarAtts', 'dim_avg_n_Wrap', 'dim_avg_wgt_n_Wrap',
+                'dim_avg_wgt_Wrap', 'dim_avg_Wrap', 'dim_cumsum_n_Wrap',
+                'dim_cumsum_Wrap', 'dim_max_n_Wrap', 'dim_min_n_Wrap',
+                'dim_rmsd_n_Wrap', 'dim_rmsd_Wrap', 'dim_rmvmean_n_Wrap',
+                'dim_rmvmean_Wrap', 'dim_rmvmed_n_Wrap', 'dim_rmvmed_Wrap',
+                'dim_standardize_n_Wrap', 'dim_standardize_Wrap',
+                'dim_stddev_n_Wrap', 'dim_stddev_Wrap', 'dim_sum_n_Wrap',
+                'dim_sum_wgt_n_Wrap', 'dim_sum_wgt_Wrap', 'dim_sum_Wrap',
+                'dim_variance_n_Wrap', 'dim_variance_Wrap', 'dpres_plevel_Wrap',
+                'dtrend_leftdim', 'dv2uvF_Wrap', 'dv2uvG_Wrap', 'eof_north',
+                'eofcor_Wrap', 'eofcov_Wrap', 'eofunc_north', 'eofunc_ts_Wrap',
+                'eofunc_varimax_reorder', 'eofunc_varimax_Wrap', 'eofunc_Wrap',
+                'epsZero', 'f2fosh_Wrap', 'f2foshv_Wrap', 'f2fsh_Wrap',
+                'f2fshv_Wrap', 'f2gsh_Wrap', 'f2gshv_Wrap', 'fbindirSwap',
+                'fbinseqSwap1', 'fbinseqSwap2', 'flt2dble', 'flt2string',
+                'fo2fsh_Wrap', 'fo2fshv_Wrap', 'g2fsh_Wrap', 'g2fshv_Wrap',
+                'g2gsh_Wrap', 'g2gshv_Wrap', 'generate_resample_indices',
+                'generate_sample_indices', 'generate_unique_indices',
+                'genNormalDist', 'get1Dindex', 'get1Dindex_Collapse',
+                'get1Dindex_Exclude', 'get_file_suffix', 'GetFillColor',
+                'GetFillColorIndex', 'getFillValue', 'getind_latlon2d',
+                'getVarDimNames', 'getVarFillValue', 'grib_stime2itime',
+                'hyi2hyo_Wrap', 'ilapsF_Wrap', 'ilapsG_Wrap', 'ind_nearest_coord',
+                'indStrSubset', 'int2dble', 'int2flt', 'int2p_n_Wrap', 'int2p_Wrap',
+                'isMonotonic', 'isStrSubset', 'latGau', 'latGauWgt', 'latGlobeF',
+                'latGlobeFo', 'latRegWgt', 'linint1_n_Wrap', 'linint1_Wrap',
+                'linint2_points_Wrap', 'linint2_Wrap', 'local_max_1d',
+                'local_min_1d', 'lonFlip', 'lonGlobeF', 'lonGlobeFo', 'lonPivot',
+                'merge_levels_sfc', 'mod', 'month_to_annual',
+                'month_to_annual_weighted', 'month_to_season', 'month_to_season12',
+                'month_to_seasonN', 'monthly_total_to_daily_mean', 'nameDim',
+                'natgrid_Wrap', 'NewCosWeight', 'niceLatLon2D', 'NormCosWgtGlobe',
+                'numAsciiCol', 'numAsciiRow', 'numeric2int',
+                'obj_anal_ic_deprecated', 'obj_anal_ic_Wrap', 'omega_ccm_driver',
+                'omega_to_w', 'oneDtostring', 'pack_values', 'pattern_cor', 'pdfx',
+                'pdfxy', 'pdfxy_conform', 'pot_temp', 'pot_vort_hybrid',
+                'pot_vort_isobaric', 'pres2hybrid_Wrap', 'print_clock',
+                'printMinMax', 'quadroots', 'rcm2points_Wrap', 'rcm2rgrid_Wrap',
+                'readAsciiHead', 'readAsciiTable', 'reg_multlin_stats',
+                'region_ind', 'regline_stats', 'relhum_ttd', 'replaceSingleChar',
+                'RGBtoCmap', 'rgrid2rcm_Wrap', 'rho_mwjf', 'rm_single_dims',
+                'rmAnnCycle1D', 'rmInsufData', 'rmMonAnnCycLLLT', 'rmMonAnnCycLLT',
+                'rmMonAnnCycTLL', 'runave_n_Wrap', 'runave_Wrap', 'short2flt',
+                'short2flt_hdf', 'shsgc_R42_Wrap', 'sign_f90', 'sign_matlab',
+                'smth9_Wrap', 'smthClmDayTLL', 'smthClmDayTLLL', 'SqrtCosWeight',
+                'stat_dispersion', 'static_stability', 'stdMonLLLT', 'stdMonLLT',
+                'stdMonTLL', 'stdMonTLLL', 'symMinMaxPlt', 'table_attach_columns',
+                'table_attach_rows', 'time_to_newtime', 'transpose',
+                'triple2grid_Wrap', 'ut_convert', 'uv2dvF_Wrap', 'uv2dvG_Wrap',
+                'uv2vrF_Wrap', 'uv2vrG_Wrap', 'vr2uvF_Wrap', 'vr2uvG_Wrap',
+                'w_to_omega', 'wallClockElapseTime', 'wave_number_spc',
+                'wgt_areaave_Wrap', 'wgt_runave_leftdim', 'wgt_runave_n_Wrap',
+                'wgt_runave_Wrap', 'wgt_vertical_n', 'wind_component',
+                'wind_direction', 'yyyyddd_to_yyyymmdd', 'yyyymm_time',
+                'yyyymm_to_yyyyfrac', 'yyyymmdd_time', 'yyyymmdd_to_yyyyddd',
+                'yyyymmdd_to_yyyyfrac', 'yyyymmddhh_time', 'yyyymmddhh_to_yyyyfrac',
+                'zonal_mpsi_Wrap', 'zonalAve', 'calendar_decode2', 'cd_string',
+                'kf_filter', 'run_cor', 'time_axis_labels', 'ut_string',
+                'wrf_contour', 'wrf_map', 'wrf_map_overlay', 'wrf_map_overlays',
+                'wrf_map_resources', 'wrf_map_zoom', 'wrf_overlay', 'wrf_overlays',
+                'wrf_user_getvar', 'wrf_user_ij_to_ll', 'wrf_user_intrp2d',
+                'wrf_user_intrp3d', 'wrf_user_latlon_to_ij', 'wrf_user_list_times',
+                'wrf_user_ll_to_ij', 'wrf_user_unstagger', 'wrf_user_vert_interp',
+                'wrf_vector', 'gsn_add_annotation', 'gsn_add_polygon',
+                'gsn_add_polyline', 'gsn_add_polymarker',
+                'gsn_add_shapefile_polygons', 'gsn_add_shapefile_polylines',
+                'gsn_add_shapefile_polymarkers', 'gsn_add_text', 'gsn_attach_plots',
+                'gsn_blank_plot', 'gsn_contour', 'gsn_contour_map',
+                'gsn_contour_shade', 'gsn_coordinates', 'gsn_create_labelbar',
+                'gsn_create_legend', 'gsn_create_text',
+                'gsn_csm_attach_zonal_means', 'gsn_csm_blank_plot',
+                'gsn_csm_contour', 'gsn_csm_contour_map', 'gsn_csm_contour_map_ce',
+                'gsn_csm_contour_map_overlay', 'gsn_csm_contour_map_polar',
+                'gsn_csm_hov', 'gsn_csm_lat_time', 'gsn_csm_map', 'gsn_csm_map_ce',
+                'gsn_csm_map_polar', 'gsn_csm_pres_hgt',
+                'gsn_csm_pres_hgt_streamline', 'gsn_csm_pres_hgt_vector',
+                'gsn_csm_streamline', 'gsn_csm_streamline_contour_map',
+                'gsn_csm_streamline_contour_map_ce',
+                'gsn_csm_streamline_contour_map_polar', 'gsn_csm_streamline_map',
+                'gsn_csm_streamline_map_ce', 'gsn_csm_streamline_map_polar',
+                'gsn_csm_streamline_scalar', 'gsn_csm_streamline_scalar_map',
+                'gsn_csm_streamline_scalar_map_ce',
+                'gsn_csm_streamline_scalar_map_polar', 'gsn_csm_time_lat',
+                'gsn_csm_vector', 'gsn_csm_vector_map', 'gsn_csm_vector_map_ce',
+                'gsn_csm_vector_map_polar', 'gsn_csm_vector_scalar',
+                'gsn_csm_vector_scalar_map', 'gsn_csm_vector_scalar_map_ce',
+                'gsn_csm_vector_scalar_map_polar', 'gsn_csm_x2y', 'gsn_csm_x2y2',
+                'gsn_csm_xy', 'gsn_csm_xy2', 'gsn_csm_xy3', 'gsn_csm_y',
+                'gsn_define_colormap', 'gsn_draw_colormap', 'gsn_draw_named_colors',
+                'gsn_histogram', 'gsn_labelbar_ndc', 'gsn_legend_ndc', 'gsn_map',
+                'gsn_merge_colormaps', 'gsn_open_wks', 'gsn_panel', 'gsn_polygon',
+                'gsn_polygon_ndc', 'gsn_polyline', 'gsn_polyline_ndc',
+                'gsn_polymarker', 'gsn_polymarker_ndc', 'gsn_retrieve_colormap',
+                'gsn_reverse_colormap', 'gsn_streamline', 'gsn_streamline_map',
+                'gsn_streamline_scalar', 'gsn_streamline_scalar_map', 'gsn_table',
+                'gsn_text', 'gsn_text_ndc', 'gsn_vector', 'gsn_vector_map',
+                'gsn_vector_scalar', 'gsn_vector_scalar_map', 'gsn_xy', 'gsn_y',
+                'hsv2rgb', 'maximize_output', 'namedcolor2rgb', 'namedcolor2rgba',
+                'reset_device_coordinates', 'span_named_colors'), prefix=r'\b'),
+             Name.Builtin),
+
+            # Resources
+            (words((
+                'amDataXF', 'amDataYF', 'amJust', 'amOn', 'amOrthogonalPosF',
+                'amParallelPosF', 'amResizeNotify', 'amSide', 'amTrackData',
+                'amViewId', 'amZone', 'appDefaultParent', 'appFileSuffix',
+                'appResources', 'appSysDir', 'appUsrDir', 'caCopyArrays',
+                'caXArray', 'caXCast', 'caXMaxV', 'caXMinV', 'caXMissingV',
+                'caYArray', 'caYCast', 'caYMaxV', 'caYMinV', 'caYMissingV',
+                'cnCellFillEdgeColor', 'cnCellFillMissingValEdgeColor',
+                'cnConpackParams', 'cnConstFEnableFill', 'cnConstFLabelAngleF',
+                'cnConstFLabelBackgroundColor', 'cnConstFLabelConstantSpacingF',
+                'cnConstFLabelFont', 'cnConstFLabelFontAspectF',
+                'cnConstFLabelFontColor', 'cnConstFLabelFontHeightF',
+                'cnConstFLabelFontQuality', 'cnConstFLabelFontThicknessF',
+                'cnConstFLabelFormat', 'cnConstFLabelFuncCode', 'cnConstFLabelJust',
+                'cnConstFLabelOn', 'cnConstFLabelOrthogonalPosF',
+                'cnConstFLabelParallelPosF', 'cnConstFLabelPerimColor',
+                'cnConstFLabelPerimOn', 'cnConstFLabelPerimSpaceF',
+                'cnConstFLabelPerimThicknessF', 'cnConstFLabelSide',
+                'cnConstFLabelString', 'cnConstFLabelTextDirection',
+                'cnConstFLabelZone', 'cnConstFUseInfoLabelRes',
+                'cnExplicitLabelBarLabelsOn', 'cnExplicitLegendLabelsOn',
+                'cnExplicitLineLabelsOn', 'cnFillBackgroundColor', 'cnFillColor',
+                'cnFillColors', 'cnFillDotSizeF', 'cnFillDrawOrder', 'cnFillMode',
+                'cnFillOn', 'cnFillOpacityF', 'cnFillPalette', 'cnFillPattern',
+                'cnFillPatterns', 'cnFillScaleF', 'cnFillScales', 'cnFixFillBleed',
+                'cnGridBoundFillColor', 'cnGridBoundFillPattern',
+                'cnGridBoundFillScaleF', 'cnGridBoundPerimColor',
+                'cnGridBoundPerimDashPattern', 'cnGridBoundPerimOn',
+                'cnGridBoundPerimThicknessF', 'cnHighLabelAngleF',
+                'cnHighLabelBackgroundColor', 'cnHighLabelConstantSpacingF',
+                'cnHighLabelCount', 'cnHighLabelFont', 'cnHighLabelFontAspectF',
+                'cnHighLabelFontColor', 'cnHighLabelFontHeightF',
+                'cnHighLabelFontQuality', 'cnHighLabelFontThicknessF',
+                'cnHighLabelFormat', 'cnHighLabelFuncCode', 'cnHighLabelPerimColor',
+                'cnHighLabelPerimOn', 'cnHighLabelPerimSpaceF',
+                'cnHighLabelPerimThicknessF', 'cnHighLabelString', 'cnHighLabelsOn',
+                'cnHighLowLabelOverlapMode', 'cnHighUseLineLabelRes',
+                'cnInfoLabelAngleF', 'cnInfoLabelBackgroundColor',
+                'cnInfoLabelConstantSpacingF', 'cnInfoLabelFont',
+                'cnInfoLabelFontAspectF', 'cnInfoLabelFontColor',
+                'cnInfoLabelFontHeightF', 'cnInfoLabelFontQuality',
+                'cnInfoLabelFontThicknessF', 'cnInfoLabelFormat',
+                'cnInfoLabelFuncCode', 'cnInfoLabelJust', 'cnInfoLabelOn',
+                'cnInfoLabelOrthogonalPosF', 'cnInfoLabelParallelPosF',
+                'cnInfoLabelPerimColor', 'cnInfoLabelPerimOn',
+                'cnInfoLabelPerimSpaceF', 'cnInfoLabelPerimThicknessF',
+                'cnInfoLabelSide', 'cnInfoLabelString', 'cnInfoLabelTextDirection',
+                'cnInfoLabelZone', 'cnLabelBarEndLabelsOn', 'cnLabelBarEndStyle',
+                'cnLabelDrawOrder', 'cnLabelMasking', 'cnLabelScaleFactorF',
+                'cnLabelScaleValueF', 'cnLabelScalingMode', 'cnLegendLevelFlags',
+                'cnLevelCount', 'cnLevelFlag', 'cnLevelFlags', 'cnLevelSelectionMode',
+                'cnLevelSpacingF', 'cnLevels', 'cnLineColor', 'cnLineColors',
+                'cnLineDashPattern', 'cnLineDashPatterns', 'cnLineDashSegLenF',
+                'cnLineDrawOrder', 'cnLineLabelAngleF', 'cnLineLabelBackgroundColor',
+                'cnLineLabelConstantSpacingF', 'cnLineLabelCount',
+                'cnLineLabelDensityF', 'cnLineLabelFont', 'cnLineLabelFontAspectF',
+                'cnLineLabelFontColor', 'cnLineLabelFontColors',
+                'cnLineLabelFontHeightF', 'cnLineLabelFontQuality',
+                'cnLineLabelFontThicknessF', 'cnLineLabelFormat',
+                'cnLineLabelFuncCode', 'cnLineLabelInterval', 'cnLineLabelPerimColor',
+                'cnLineLabelPerimOn', 'cnLineLabelPerimSpaceF',
+                'cnLineLabelPerimThicknessF', 'cnLineLabelPlacementMode',
+                'cnLineLabelStrings', 'cnLineLabelsOn', 'cnLinePalette',
+                'cnLineThicknessF', 'cnLineThicknesses', 'cnLinesOn',
+                'cnLowLabelAngleF', 'cnLowLabelBackgroundColor',
+                'cnLowLabelConstantSpacingF', 'cnLowLabelCount', 'cnLowLabelFont',
+                'cnLowLabelFontAspectF', 'cnLowLabelFontColor',
+                'cnLowLabelFontHeightF', 'cnLowLabelFontQuality',
+                'cnLowLabelFontThicknessF', 'cnLowLabelFormat', 'cnLowLabelFuncCode',
+                'cnLowLabelPerimColor', 'cnLowLabelPerimOn', 'cnLowLabelPerimSpaceF',
+                'cnLowLabelPerimThicknessF', 'cnLowLabelString', 'cnLowLabelsOn',
+                'cnLowUseHighLabelRes', 'cnMaxDataValueFormat', 'cnMaxLevelCount',
+                'cnMaxLevelValF', 'cnMaxPointDistanceF', 'cnMinLevelValF',
+                'cnMissingValFillColor', 'cnMissingValFillPattern',
+                'cnMissingValFillScaleF', 'cnMissingValPerimColor',
+                'cnMissingValPerimDashPattern', 'cnMissingValPerimGridBoundOn',
+                'cnMissingValPerimOn', 'cnMissingValPerimThicknessF',
+                'cnMonoFillColor', 'cnMonoFillPattern', 'cnMonoFillScale',
+                'cnMonoLevelFlag', 'cnMonoLineColor', 'cnMonoLineDashPattern',
+                'cnMonoLineLabelFontColor', 'cnMonoLineThickness', 'cnNoDataLabelOn',
+                'cnNoDataLabelString', 'cnOutOfRangeFillColor',
+                'cnOutOfRangeFillPattern', 'cnOutOfRangeFillScaleF',
+                'cnOutOfRangePerimColor', 'cnOutOfRangePerimDashPattern',
+                'cnOutOfRangePerimOn', 'cnOutOfRangePerimThicknessF',
+                'cnRasterCellSizeF', 'cnRasterMinCellSizeF', 'cnRasterModeOn',
+                'cnRasterSampleFactorF', 'cnRasterSmoothingOn', 'cnScalarFieldData',
+                'cnSmoothingDistanceF', 'cnSmoothingOn', 'cnSmoothingTensionF',
+                'cnSpanFillPalette', 'cnSpanLinePalette', 'ctCopyTables',
+                'ctXElementSize', 'ctXMaxV', 'ctXMinV', 'ctXMissingV', 'ctXTable',
+                'ctXTableLengths', 'ctXTableType', 'ctYElementSize', 'ctYMaxV',
+                'ctYMinV', 'ctYMissingV', 'ctYTable', 'ctYTableLengths',
+                'ctYTableType', 'dcDelayCompute', 'errBuffer',
+                'errFileName', 'errFilePtr', 'errLevel', 'errPrint', 'errUnitNumber',
+                'gsClipOn', 'gsColors', 'gsEdgeColor', 'gsEdgeDashPattern',
+                'gsEdgeDashSegLenF', 'gsEdgeThicknessF', 'gsEdgesOn',
+                'gsFillBackgroundColor', 'gsFillColor', 'gsFillDotSizeF',
+                'gsFillIndex', 'gsFillLineThicknessF', 'gsFillOpacityF',
+                'gsFillScaleF', 'gsFont', 'gsFontAspectF', 'gsFontColor',
+                'gsFontHeightF', 'gsFontOpacityF', 'gsFontQuality',
+                'gsFontThicknessF', 'gsLineColor', 'gsLineDashPattern',
+                'gsLineDashSegLenF', 'gsLineLabelConstantSpacingF', 'gsLineLabelFont',
+                'gsLineLabelFontAspectF', 'gsLineLabelFontColor',
+                'gsLineLabelFontHeightF', 'gsLineLabelFontQuality',
+                'gsLineLabelFontThicknessF', 'gsLineLabelFuncCode',
+                'gsLineLabelString', 'gsLineOpacityF', 'gsLineThicknessF',
+                'gsMarkerColor', 'gsMarkerIndex', 'gsMarkerOpacityF', 'gsMarkerSizeF',
+                'gsMarkerThicknessF', 'gsSegments', 'gsTextAngleF',
+                'gsTextConstantSpacingF', 'gsTextDirection', 'gsTextFuncCode',
+                'gsTextJustification', 'gsnAboveYRefLineBarColors',
+                'gsnAboveYRefLineBarFillScales', 'gsnAboveYRefLineBarPatterns',
+                'gsnAboveYRefLineColor', 'gsnAddCyclic', 'gsnAttachBorderOn',
+                'gsnAttachPlotsXAxis', 'gsnBelowYRefLineBarColors',
+                'gsnBelowYRefLineBarFillScales', 'gsnBelowYRefLineBarPatterns',
+                'gsnBelowYRefLineColor', 'gsnBoxMargin', 'gsnCenterString',
+                'gsnCenterStringFontColor', 'gsnCenterStringFontHeightF',
+                'gsnCenterStringFuncCode', 'gsnCenterStringOrthogonalPosF',
+                'gsnCenterStringParallelPosF', 'gsnContourLineThicknessesScale',
+                'gsnContourNegLineDashPattern', 'gsnContourPosLineDashPattern',
+                'gsnContourZeroLineThicknessF', 'gsnDebugWriteFileName', 'gsnDraw',
+                'gsnFrame', 'gsnHistogramBarWidthPercent', 'gsnHistogramBinIntervals',
+                'gsnHistogramBinMissing', 'gsnHistogramBinWidth',
+                'gsnHistogramClassIntervals', 'gsnHistogramCompare',
+                'gsnHistogramComputePercentages',
+                'gsnHistogramComputePercentagesNoMissing',
+                'gsnHistogramDiscreteBinValues', 'gsnHistogramDiscreteClassValues',
+                'gsnHistogramHorizontal', 'gsnHistogramMinMaxBinsOn',
+                'gsnHistogramNumberOfBins', 'gsnHistogramPercentSign',
+                'gsnHistogramSelectNiceIntervals', 'gsnLeftString',
+                'gsnLeftStringFontColor', 'gsnLeftStringFontHeightF',
+                'gsnLeftStringFuncCode', 'gsnLeftStringOrthogonalPosF',
+                'gsnLeftStringParallelPosF', 'gsnMajorLatSpacing',
+                'gsnMajorLonSpacing', 'gsnMaskLambertConformal',
+                'gsnMaskLambertConformalOutlineOn', 'gsnMaximize',
+                'gsnMinorLatSpacing', 'gsnMinorLonSpacing', 'gsnPanelBottom',
+                'gsnPanelCenter', 'gsnPanelDebug', 'gsnPanelFigureStrings',
+                'gsnPanelFigureStringsBackgroundFillColor',
+                'gsnPanelFigureStringsFontHeightF', 'gsnPanelFigureStringsJust',
+                'gsnPanelFigureStringsPerimOn', 'gsnPanelLabelBar', 'gsnPanelLeft',
+                'gsnPanelMainFont', 'gsnPanelMainFontColor',
+                'gsnPanelMainFontHeightF', 'gsnPanelMainString', 'gsnPanelRight',
+                'gsnPanelRowSpec', 'gsnPanelScalePlotIndex', 'gsnPanelTop',
+                'gsnPanelXF', 'gsnPanelXWhiteSpacePercent', 'gsnPanelYF',
+                'gsnPanelYWhiteSpacePercent', 'gsnPaperHeight', 'gsnPaperMargin',
+                'gsnPaperOrientation', 'gsnPaperWidth', 'gsnPolar',
+                'gsnPolarLabelDistance', 'gsnPolarLabelFont',
+                'gsnPolarLabelFontHeightF', 'gsnPolarLabelSpacing', 'gsnPolarTime',
+                'gsnPolarUT', 'gsnRightString', 'gsnRightStringFontColor',
+                'gsnRightStringFontHeightF', 'gsnRightStringFuncCode',
+                'gsnRightStringOrthogonalPosF', 'gsnRightStringParallelPosF',
+                'gsnScalarContour', 'gsnScale', 'gsnShape', 'gsnSpreadColorEnd',
+                'gsnSpreadColorStart', 'gsnSpreadColors', 'gsnStringFont',
+                'gsnStringFontColor', 'gsnStringFontHeightF', 'gsnStringFuncCode',
+                'gsnTickMarksOn', 'gsnXAxisIrregular2Linear', 'gsnXAxisIrregular2Log',
+                'gsnXRefLine', 'gsnXRefLineColor', 'gsnXRefLineDashPattern',
+                'gsnXRefLineThicknessF', 'gsnXYAboveFillColors', 'gsnXYBarChart',
+                'gsnXYBarChartBarWidth', 'gsnXYBarChartColors',
+                'gsnXYBarChartColors2', 'gsnXYBarChartFillDotSizeF',
+                'gsnXYBarChartFillLineThicknessF', 'gsnXYBarChartFillOpacityF',
+                'gsnXYBarChartFillScaleF', 'gsnXYBarChartOutlineOnly',
+                'gsnXYBarChartOutlineThicknessF', 'gsnXYBarChartPatterns',
+                'gsnXYBarChartPatterns2', 'gsnXYBelowFillColors', 'gsnXYFillColors',
+                'gsnXYFillOpacities', 'gsnXYLeftFillColors', 'gsnXYRightFillColors',
+                'gsnYAxisIrregular2Linear', 'gsnYAxisIrregular2Log', 'gsnYRefLine',
+                'gsnYRefLineColor', 'gsnYRefLineColors', 'gsnYRefLineDashPattern',
+                'gsnYRefLineDashPatterns', 'gsnYRefLineThicknessF',
+                'gsnYRefLineThicknesses', 'gsnZonalMean', 'gsnZonalMeanXMaxF',
+                'gsnZonalMeanXMinF', 'gsnZonalMeanYRefLine', 'lbAutoManage',
+                'lbBottomMarginF', 'lbBoxCount', 'lbBoxEndCapStyle', 'lbBoxFractions',
+                'lbBoxLineColor', 'lbBoxLineDashPattern', 'lbBoxLineDashSegLenF',
+                'lbBoxLineThicknessF', 'lbBoxLinesOn', 'lbBoxMajorExtentF',
+                'lbBoxMinorExtentF', 'lbBoxSeparatorLinesOn', 'lbBoxSizing',
+                'lbFillBackground', 'lbFillColor', 'lbFillColors', 'lbFillDotSizeF',
+                'lbFillLineThicknessF', 'lbFillPattern', 'lbFillPatterns',
+                'lbFillScaleF', 'lbFillScales', 'lbJustification', 'lbLabelAlignment',
+                'lbLabelAngleF', 'lbLabelAutoStride', 'lbLabelBarOn',
+                'lbLabelConstantSpacingF', 'lbLabelDirection', 'lbLabelFont',
+                'lbLabelFontAspectF', 'lbLabelFontColor', 'lbLabelFontHeightF',
+                'lbLabelFontQuality', 'lbLabelFontThicknessF', 'lbLabelFuncCode',
+                'lbLabelJust', 'lbLabelOffsetF', 'lbLabelPosition', 'lbLabelStride',
+                'lbLabelStrings', 'lbLabelsOn', 'lbLeftMarginF', 'lbMaxLabelLenF',
+                'lbMinLabelSpacingF', 'lbMonoFillColor', 'lbMonoFillPattern',
+                'lbMonoFillScale', 'lbOrientation', 'lbPerimColor',
+                'lbPerimDashPattern', 'lbPerimDashSegLenF', 'lbPerimFill',
+                'lbPerimFillColor', 'lbPerimOn', 'lbPerimThicknessF',
+                'lbRasterFillOn', 'lbRightMarginF', 'lbTitleAngleF',
+                'lbTitleConstantSpacingF', 'lbTitleDirection', 'lbTitleExtentF',
+                'lbTitleFont', 'lbTitleFontAspectF', 'lbTitleFontColor',
+                'lbTitleFontHeightF', 'lbTitleFontQuality', 'lbTitleFontThicknessF',
+                'lbTitleFuncCode', 'lbTitleJust', 'lbTitleOffsetF', 'lbTitleOn',
+                'lbTitlePosition', 'lbTitleString', 'lbTopMarginF', 'lgAutoManage',
+                'lgBottomMarginF', 'lgBoxBackground', 'lgBoxLineColor',
+                'lgBoxLineDashPattern', 'lgBoxLineDashSegLenF', 'lgBoxLineThicknessF',
+                'lgBoxLinesOn', 'lgBoxMajorExtentF', 'lgBoxMinorExtentF',
+                'lgDashIndex', 'lgDashIndexes', 'lgItemCount', 'lgItemOrder',
+                'lgItemPlacement', 'lgItemPositions', 'lgItemType', 'lgItemTypes',
+                'lgJustification', 'lgLabelAlignment', 'lgLabelAngleF',
+                'lgLabelAutoStride', 'lgLabelConstantSpacingF', 'lgLabelDirection',
+                'lgLabelFont', 'lgLabelFontAspectF', 'lgLabelFontColor',
+                'lgLabelFontHeightF', 'lgLabelFontQuality', 'lgLabelFontThicknessF',
+                'lgLabelFuncCode', 'lgLabelJust', 'lgLabelOffsetF', 'lgLabelPosition',
+                'lgLabelStride', 'lgLabelStrings', 'lgLabelsOn', 'lgLeftMarginF',
+                'lgLegendOn', 'lgLineColor', 'lgLineColors', 'lgLineDashSegLenF',
+                'lgLineDashSegLens', 'lgLineLabelConstantSpacingF', 'lgLineLabelFont',
+                'lgLineLabelFontAspectF', 'lgLineLabelFontColor',
+                'lgLineLabelFontColors', 'lgLineLabelFontHeightF',
+                'lgLineLabelFontHeights', 'lgLineLabelFontQuality',
+                'lgLineLabelFontThicknessF', 'lgLineLabelFuncCode',
+                'lgLineLabelStrings', 'lgLineLabelsOn', 'lgLineThicknessF',
+                'lgLineThicknesses', 'lgMarkerColor', 'lgMarkerColors',
+                'lgMarkerIndex', 'lgMarkerIndexes', 'lgMarkerSizeF', 'lgMarkerSizes',
+                'lgMarkerThicknessF', 'lgMarkerThicknesses', 'lgMonoDashIndex',
+                'lgMonoItemType', 'lgMonoLineColor', 'lgMonoLineDashSegLen',
+                'lgMonoLineLabelFontColor', 'lgMonoLineLabelFontHeight',
+                'lgMonoLineThickness', 'lgMonoMarkerColor', 'lgMonoMarkerIndex',
+                'lgMonoMarkerSize', 'lgMonoMarkerThickness', 'lgOrientation',
+                'lgPerimColor', 'lgPerimDashPattern', 'lgPerimDashSegLenF',
+                'lgPerimFill', 'lgPerimFillColor', 'lgPerimOn', 'lgPerimThicknessF',
+                'lgRightMarginF', 'lgTitleAngleF', 'lgTitleConstantSpacingF',
+                'lgTitleDirection', 'lgTitleExtentF', 'lgTitleFont',
+                'lgTitleFontAspectF', 'lgTitleFontColor', 'lgTitleFontHeightF',
+                'lgTitleFontQuality', 'lgTitleFontThicknessF', 'lgTitleFuncCode',
+                'lgTitleJust', 'lgTitleOffsetF', 'lgTitleOn', 'lgTitlePosition',
+                'lgTitleString', 'lgTopMarginF', 'mpAreaGroupCount',
+                'mpAreaMaskingOn', 'mpAreaNames', 'mpAreaTypes', 'mpBottomAngleF',
+                'mpBottomMapPosF', 'mpBottomNDCF', 'mpBottomNPCF',
+                'mpBottomPointLatF', 'mpBottomPointLonF', 'mpBottomWindowF',
+                'mpCenterLatF', 'mpCenterLonF', 'mpCenterRotF', 'mpCountyLineColor',
+                'mpCountyLineDashPattern', 'mpCountyLineDashSegLenF',
+                'mpCountyLineThicknessF', 'mpDataBaseVersion', 'mpDataResolution',
+                'mpDataSetName', 'mpDefaultFillColor', 'mpDefaultFillPattern',
+                'mpDefaultFillScaleF', 'mpDynamicAreaGroups', 'mpEllipticalBoundary',
+                'mpFillAreaSpecifiers', 'mpFillBoundarySets', 'mpFillColor',
+                'mpFillColors', 'mpFillColors-default', 'mpFillDotSizeF',
+                'mpFillDrawOrder', 'mpFillOn', 'mpFillPatternBackground',
+                'mpFillPattern', 'mpFillPatterns', 'mpFillPatterns-default',
+                'mpFillScaleF', 'mpFillScales', 'mpFillScales-default',
+                'mpFixedAreaGroups', 'mpGeophysicalLineColor',
+                'mpGeophysicalLineDashPattern', 'mpGeophysicalLineDashSegLenF',
+                'mpGeophysicalLineThicknessF', 'mpGreatCircleLinesOn',
+                'mpGridAndLimbDrawOrder', 'mpGridAndLimbOn', 'mpGridLatSpacingF',
+                'mpGridLineColor', 'mpGridLineDashPattern', 'mpGridLineDashSegLenF',
+                'mpGridLineThicknessF', 'mpGridLonSpacingF', 'mpGridMaskMode',
+                'mpGridMaxLatF', 'mpGridPolarLonSpacingF', 'mpGridSpacingF',
+                'mpInlandWaterFillColor', 'mpInlandWaterFillPattern',
+                'mpInlandWaterFillScaleF', 'mpLabelDrawOrder', 'mpLabelFontColor',
+                'mpLabelFontHeightF', 'mpLabelsOn', 'mpLambertMeridianF',
+                'mpLambertParallel1F', 'mpLambertParallel2F', 'mpLandFillColor',
+                'mpLandFillPattern', 'mpLandFillScaleF', 'mpLeftAngleF',
+                'mpLeftCornerLatF', 'mpLeftCornerLonF', 'mpLeftMapPosF',
+                'mpLeftNDCF', 'mpLeftNPCF', 'mpLeftPointLatF',
+                'mpLeftPointLonF', 'mpLeftWindowF', 'mpLimbLineColor',
+                'mpLimbLineDashPattern', 'mpLimbLineDashSegLenF',
+                'mpLimbLineThicknessF', 'mpLimitMode', 'mpMaskAreaSpecifiers',
+                'mpMaskOutlineSpecifiers', 'mpMaxLatF', 'mpMaxLonF',
+                'mpMinLatF', 'mpMinLonF', 'mpMonoFillColor', 'mpMonoFillPattern',
+                'mpMonoFillScale', 'mpNationalLineColor', 'mpNationalLineDashPattern',
+                'mpNationalLineThicknessF', 'mpOceanFillColor', 'mpOceanFillPattern',
+                'mpOceanFillScaleF', 'mpOutlineBoundarySets', 'mpOutlineDrawOrder',
+                'mpOutlineMaskingOn', 'mpOutlineOn', 'mpOutlineSpecifiers',
+                'mpPerimDrawOrder', 'mpPerimLineColor', 'mpPerimLineDashPattern',
+                'mpPerimLineDashSegLenF', 'mpPerimLineThicknessF', 'mpPerimOn',
+                'mpPolyMode', 'mpProjection', 'mpProvincialLineColor',
+                'mpProvincialLineDashPattern', 'mpProvincialLineDashSegLenF',
+                'mpProvincialLineThicknessF', 'mpRelativeCenterLat',
+                'mpRelativeCenterLon', 'mpRightAngleF', 'mpRightCornerLatF',
+                'mpRightCornerLonF', 'mpRightMapPosF', 'mpRightNDCF',
+                'mpRightNPCF', 'mpRightPointLatF', 'mpRightPointLonF',
+                'mpRightWindowF', 'mpSatelliteAngle1F', 'mpSatelliteAngle2F',
+                'mpSatelliteDistF', 'mpShapeMode', 'mpSpecifiedFillColors',
+                'mpSpecifiedFillDirectIndexing', 'mpSpecifiedFillPatterns',
+                'mpSpecifiedFillPriority', 'mpSpecifiedFillScales',
+                'mpTopAngleF', 'mpTopMapPosF', 'mpTopNDCF', 'mpTopNPCF',
+                'mpTopPointLatF', 'mpTopPointLonF', 'mpTopWindowF',
+                'mpUSStateLineColor', 'mpUSStateLineDashPattern',
+                'mpUSStateLineDashSegLenF', 'mpUSStateLineThicknessF',
+                'pmAnnoManagers', 'pmAnnoViews', 'pmLabelBarDisplayMode',
+                'pmLabelBarHeightF', 'pmLabelBarKeepAspect', 'pmLabelBarOrthogonalPosF',
+                'pmLabelBarParallelPosF', 'pmLabelBarSide', 'pmLabelBarWidthF',
+                'pmLabelBarZone', 'pmLegendDisplayMode', 'pmLegendHeightF',
+                'pmLegendKeepAspect', 'pmLegendOrthogonalPosF',
+                'pmLegendParallelPosF', 'pmLegendSide', 'pmLegendWidthF',
+                'pmLegendZone', 'pmOverlaySequenceIds', 'pmTickMarkDisplayMode',
+                'pmTickMarkZone', 'pmTitleDisplayMode', 'pmTitleZone',
+                'prGraphicStyle', 'prPolyType', 'prXArray', 'prYArray',
+                'sfCopyData', 'sfDataArray', 'sfDataMaxV', 'sfDataMinV',
+                'sfElementNodes', 'sfExchangeDimensions', 'sfFirstNodeIndex',
+                'sfMissingValueV', 'sfXArray', 'sfXCActualEndF', 'sfXCActualStartF',
+                'sfXCEndIndex', 'sfXCEndSubsetV', 'sfXCEndV', 'sfXCStartIndex',
+                'sfXCStartSubsetV', 'sfXCStartV', 'sfXCStride', 'sfXCellBounds',
+                'sfYArray', 'sfYCActualEndF', 'sfYCActualStartF', 'sfYCEndIndex',
+                'sfYCEndSubsetV', 'sfYCEndV', 'sfYCStartIndex', 'sfYCStartSubsetV',
+                'sfYCStartV', 'sfYCStride', 'sfYCellBounds', 'stArrowLengthF',
+                'stArrowStride', 'stCrossoverCheckCount',
+                'stExplicitLabelBarLabelsOn', 'stLabelBarEndLabelsOn',
+                'stLabelFormat', 'stLengthCheckCount', 'stLevelColors',
+                'stLevelCount', 'stLevelPalette', 'stLevelSelectionMode',
+                'stLevelSpacingF', 'stLevels', 'stLineColor', 'stLineOpacityF',
+                'stLineStartStride', 'stLineThicknessF', 'stMapDirection',
+                'stMaxLevelCount', 'stMaxLevelValF', 'stMinArrowSpacingF',
+                'stMinDistanceF', 'stMinLevelValF', 'stMinLineSpacingF',
+                'stMinStepFactorF', 'stMonoLineColor', 'stNoDataLabelOn',
+                'stNoDataLabelString', 'stScalarFieldData', 'stScalarMissingValColor',
+                'stSpanLevelPalette', 'stStepSizeF', 'stStreamlineDrawOrder',
+                'stUseScalarArray', 'stVectorFieldData', 'stZeroFLabelAngleF',
+                'stZeroFLabelBackgroundColor', 'stZeroFLabelConstantSpacingF',
+                'stZeroFLabelFont', 'stZeroFLabelFontAspectF',
+                'stZeroFLabelFontColor', 'stZeroFLabelFontHeightF',
+                'stZeroFLabelFontQuality', 'stZeroFLabelFontThicknessF',
+                'stZeroFLabelFuncCode', 'stZeroFLabelJust', 'stZeroFLabelOn',
+                'stZeroFLabelOrthogonalPosF', 'stZeroFLabelParallelPosF',
+                'stZeroFLabelPerimColor', 'stZeroFLabelPerimOn',
+                'stZeroFLabelPerimSpaceF', 'stZeroFLabelPerimThicknessF',
+                'stZeroFLabelSide', 'stZeroFLabelString', 'stZeroFLabelTextDirection',
+                'stZeroFLabelZone', 'tfDoNDCOverlay', 'tfPlotManagerOn',
+                'tfPolyDrawList', 'tfPolyDrawOrder', 'tiDeltaF', 'tiMainAngleF',
+                'tiMainConstantSpacingF', 'tiMainDirection', 'tiMainFont',
+                'tiMainFontAspectF', 'tiMainFontColor', 'tiMainFontHeightF',
+                'tiMainFontQuality', 'tiMainFontThicknessF', 'tiMainFuncCode',
+                'tiMainJust', 'tiMainOffsetXF', 'tiMainOffsetYF', 'tiMainOn',
+                'tiMainPosition', 'tiMainSide', 'tiMainString', 'tiUseMainAttributes',
+                'tiXAxisAngleF', 'tiXAxisConstantSpacingF', 'tiXAxisDirection',
+                'tiXAxisFont', 'tiXAxisFontAspectF', 'tiXAxisFontColor',
+                'tiXAxisFontHeightF', 'tiXAxisFontQuality', 'tiXAxisFontThicknessF',
+                'tiXAxisFuncCode', 'tiXAxisJust', 'tiXAxisOffsetXF',
+                'tiXAxisOffsetYF', 'tiXAxisOn', 'tiXAxisPosition', 'tiXAxisSide',
+                'tiXAxisString', 'tiYAxisAngleF', 'tiYAxisConstantSpacingF',
+                'tiYAxisDirection', 'tiYAxisFont', 'tiYAxisFontAspectF',
+                'tiYAxisFontColor', 'tiYAxisFontHeightF', 'tiYAxisFontQuality',
+                'tiYAxisFontThicknessF', 'tiYAxisFuncCode', 'tiYAxisJust',
+                'tiYAxisOffsetXF', 'tiYAxisOffsetYF', 'tiYAxisOn', 'tiYAxisPosition',
+                'tiYAxisSide', 'tiYAxisString', 'tmBorderLineColor',
+                'tmBorderThicknessF', 'tmEqualizeXYSizes', 'tmLabelAutoStride',
+                'tmSciNoteCutoff', 'tmXBAutoPrecision', 'tmXBBorderOn',
+                'tmXBDataLeftF', 'tmXBDataRightF', 'tmXBFormat', 'tmXBIrrTensionF',
+                'tmXBIrregularPoints', 'tmXBLabelAngleF', 'tmXBLabelConstantSpacingF',
+                'tmXBLabelDeltaF', 'tmXBLabelDirection', 'tmXBLabelFont',
+                'tmXBLabelFontAspectF', 'tmXBLabelFontColor', 'tmXBLabelFontHeightF',
+                'tmXBLabelFontQuality', 'tmXBLabelFontThicknessF',
+                'tmXBLabelFuncCode', 'tmXBLabelJust', 'tmXBLabelStride', 'tmXBLabels',
+                'tmXBLabelsOn', 'tmXBMajorLengthF', 'tmXBMajorLineColor',
+                'tmXBMajorOutwardLengthF', 'tmXBMajorThicknessF', 'tmXBMaxLabelLenF',
+                'tmXBMaxTicks', 'tmXBMinLabelSpacingF', 'tmXBMinorLengthF',
+                'tmXBMinorLineColor', 'tmXBMinorOn', 'tmXBMinorOutwardLengthF',
+                'tmXBMinorPerMajor', 'tmXBMinorThicknessF', 'tmXBMinorValues',
+                'tmXBMode', 'tmXBOn', 'tmXBPrecision', 'tmXBStyle', 'tmXBTickEndF',
+                'tmXBTickSpacingF', 'tmXBTickStartF', 'tmXBValues', 'tmXMajorGrid',
+                'tmXMajorGridLineColor', 'tmXMajorGridLineDashPattern',
+                'tmXMajorGridThicknessF', 'tmXMinorGrid', 'tmXMinorGridLineColor',
+                'tmXMinorGridLineDashPattern', 'tmXMinorGridThicknessF',
+                'tmXTAutoPrecision', 'tmXTBorderOn', 'tmXTDataLeftF',
+                'tmXTDataRightF', 'tmXTFormat', 'tmXTIrrTensionF',
+                'tmXTIrregularPoints', 'tmXTLabelAngleF', 'tmXTLabelConstantSpacingF',
+                'tmXTLabelDeltaF', 'tmXTLabelDirection', 'tmXTLabelFont',
+                'tmXTLabelFontAspectF', 'tmXTLabelFontColor', 'tmXTLabelFontHeightF',
+                'tmXTLabelFontQuality', 'tmXTLabelFontThicknessF',
+                'tmXTLabelFuncCode', 'tmXTLabelJust', 'tmXTLabelStride', 'tmXTLabels',
+                'tmXTLabelsOn', 'tmXTMajorLengthF', 'tmXTMajorLineColor',
+                'tmXTMajorOutwardLengthF', 'tmXTMajorThicknessF', 'tmXTMaxLabelLenF',
+                'tmXTMaxTicks', 'tmXTMinLabelSpacingF', 'tmXTMinorLengthF',
+                'tmXTMinorLineColor', 'tmXTMinorOn', 'tmXTMinorOutwardLengthF',
+                'tmXTMinorPerMajor', 'tmXTMinorThicknessF', 'tmXTMinorValues',
+                'tmXTMode', 'tmXTOn', 'tmXTPrecision', 'tmXTStyle', 'tmXTTickEndF',
+                'tmXTTickSpacingF', 'tmXTTickStartF', 'tmXTValues', 'tmXUseBottom',
+                'tmYLAutoPrecision', 'tmYLBorderOn', 'tmYLDataBottomF',
+                'tmYLDataTopF', 'tmYLFormat', 'tmYLIrrTensionF',
+                'tmYLIrregularPoints', 'tmYLLabelAngleF', 'tmYLLabelConstantSpacingF',
+                'tmYLLabelDeltaF', 'tmYLLabelDirection', 'tmYLLabelFont',
+                'tmYLLabelFontAspectF', 'tmYLLabelFontColor', 'tmYLLabelFontHeightF',
+                'tmYLLabelFontQuality', 'tmYLLabelFontThicknessF',
+                'tmYLLabelFuncCode', 'tmYLLabelJust', 'tmYLLabelStride', 'tmYLLabels',
+                'tmYLLabelsOn', 'tmYLMajorLengthF', 'tmYLMajorLineColor',
+                'tmYLMajorOutwardLengthF', 'tmYLMajorThicknessF', 'tmYLMaxLabelLenF',
+                'tmYLMaxTicks', 'tmYLMinLabelSpacingF', 'tmYLMinorLengthF',
+                'tmYLMinorLineColor', 'tmYLMinorOn', 'tmYLMinorOutwardLengthF',
+                'tmYLMinorPerMajor', 'tmYLMinorThicknessF', 'tmYLMinorValues',
+                'tmYLMode', 'tmYLOn', 'tmYLPrecision', 'tmYLStyle', 'tmYLTickEndF',
+                'tmYLTickSpacingF', 'tmYLTickStartF', 'tmYLValues', 'tmYMajorGrid',
+                'tmYMajorGridLineColor', 'tmYMajorGridLineDashPattern',
+                'tmYMajorGridThicknessF', 'tmYMinorGrid', 'tmYMinorGridLineColor',
+                'tmYMinorGridLineDashPattern', 'tmYMinorGridThicknessF',
+                'tmYRAutoPrecision', 'tmYRBorderOn', 'tmYRDataBottomF',
+                'tmYRDataTopF', 'tmYRFormat', 'tmYRIrrTensionF',
+                'tmYRIrregularPoints', 'tmYRLabelAngleF', 'tmYRLabelConstantSpacingF',
+                'tmYRLabelDeltaF', 'tmYRLabelDirection', 'tmYRLabelFont',
+                'tmYRLabelFontAspectF', 'tmYRLabelFontColor', 'tmYRLabelFontHeightF',
+                'tmYRLabelFontQuality', 'tmYRLabelFontThicknessF',
+                'tmYRLabelFuncCode', 'tmYRLabelJust', 'tmYRLabelStride', 'tmYRLabels',
+                'tmYRLabelsOn', 'tmYRMajorLengthF', 'tmYRMajorLineColor',
+                'tmYRMajorOutwardLengthF', 'tmYRMajorThicknessF', 'tmYRMaxLabelLenF',
+                'tmYRMaxTicks', 'tmYRMinLabelSpacingF', 'tmYRMinorLengthF',
+                'tmYRMinorLineColor', 'tmYRMinorOn', 'tmYRMinorOutwardLengthF',
+                'tmYRMinorPerMajor', 'tmYRMinorThicknessF', 'tmYRMinorValues',
+                'tmYRMode', 'tmYROn', 'tmYRPrecision', 'tmYRStyle', 'tmYRTickEndF',
+                'tmYRTickSpacingF', 'tmYRTickStartF', 'tmYRValues', 'tmYUseLeft',
+                'trGridType', 'trLineInterpolationOn',
+                'trXAxisType', 'trXCoordPoints', 'trXInterPoints', 'trXLog',
+                'trXMaxF', 'trXMinF', 'trXReverse', 'trXSamples', 'trXTensionF',
+                'trYAxisType', 'trYCoordPoints', 'trYInterPoints', 'trYLog',
+                'trYMaxF', 'trYMinF', 'trYReverse', 'trYSamples', 'trYTensionF',
+                'txAngleF', 'txBackgroundFillColor', 'txConstantSpacingF', 'txDirection',
+                'txFont', 'HLU-Fonts', 'txFontAspectF', 'txFontColor',
+                'txFontHeightF', 'txFontOpacityF', 'txFontQuality',
+                'txFontThicknessF', 'txFuncCode', 'txJust', 'txPerimColor',
+                'txPerimDashLengthF', 'txPerimDashPattern', 'txPerimOn',
+                'txPerimSpaceF', 'txPerimThicknessF', 'txPosXF', 'txPosYF',
+                'txString', 'vcExplicitLabelBarLabelsOn', 'vcFillArrowEdgeColor',
+                'vcFillArrowEdgeThicknessF', 'vcFillArrowFillColor',
+                'vcFillArrowHeadInteriorXF', 'vcFillArrowHeadMinFracXF',
+                'vcFillArrowHeadMinFracYF', 'vcFillArrowHeadXF', 'vcFillArrowHeadYF',
+                'vcFillArrowMinFracWidthF', 'vcFillArrowWidthF', 'vcFillArrowsOn',
+                'vcFillOverEdge', 'vcGlyphOpacityF', 'vcGlyphStyle',
+                'vcLabelBarEndLabelsOn', 'vcLabelFontColor', 'vcLabelFontHeightF',
+                'vcLabelsOn', 'vcLabelsUseVectorColor', 'vcLevelColors',
+                'vcLevelCount', 'vcLevelPalette', 'vcLevelSelectionMode',
+                'vcLevelSpacingF', 'vcLevels', 'vcLineArrowColor',
+                'vcLineArrowHeadMaxSizeF', 'vcLineArrowHeadMinSizeF',
+                'vcLineArrowThicknessF', 'vcMagnitudeFormat',
+                'vcMagnitudeScaleFactorF', 'vcMagnitudeScaleValueF',
+                'vcMagnitudeScalingMode', 'vcMapDirection', 'vcMaxLevelCount',
+                'vcMaxLevelValF', 'vcMaxMagnitudeF', 'vcMinAnnoAngleF',
+                'vcMinAnnoArrowAngleF', 'vcMinAnnoArrowEdgeColor',
+                'vcMinAnnoArrowFillColor', 'vcMinAnnoArrowLineColor',
+                'vcMinAnnoArrowMinOffsetF', 'vcMinAnnoArrowSpaceF',
+                'vcMinAnnoArrowUseVecColor', 'vcMinAnnoBackgroundColor',
+                'vcMinAnnoConstantSpacingF', 'vcMinAnnoExplicitMagnitudeF',
+                'vcMinAnnoFont', 'vcMinAnnoFontAspectF', 'vcMinAnnoFontColor',
+                'vcMinAnnoFontHeightF', 'vcMinAnnoFontQuality',
+                'vcMinAnnoFontThicknessF', 'vcMinAnnoFuncCode', 'vcMinAnnoJust',
+                'vcMinAnnoOn', 'vcMinAnnoOrientation', 'vcMinAnnoOrthogonalPosF',
+                'vcMinAnnoParallelPosF', 'vcMinAnnoPerimColor', 'vcMinAnnoPerimOn',
+                'vcMinAnnoPerimSpaceF', 'vcMinAnnoPerimThicknessF', 'vcMinAnnoSide',
+                'vcMinAnnoString1', 'vcMinAnnoString1On', 'vcMinAnnoString2',
+                'vcMinAnnoString2On', 'vcMinAnnoTextDirection', 'vcMinAnnoZone',
+                'vcMinDistanceF', 'vcMinFracLengthF', 'vcMinLevelValF',
+                'vcMinMagnitudeF', 'vcMonoFillArrowEdgeColor',
+                'vcMonoFillArrowFillColor', 'vcMonoLineArrowColor',
+                'vcMonoWindBarbColor', 'vcNoDataLabelOn', 'vcNoDataLabelString',
+                'vcPositionMode', 'vcRefAnnoAngleF', 'vcRefAnnoArrowAngleF',
+                'vcRefAnnoArrowEdgeColor', 'vcRefAnnoArrowFillColor',
+                'vcRefAnnoArrowLineColor', 'vcRefAnnoArrowMinOffsetF',
+                'vcRefAnnoArrowSpaceF', 'vcRefAnnoArrowUseVecColor',
+                'vcRefAnnoBackgroundColor', 'vcRefAnnoConstantSpacingF',
+                'vcRefAnnoExplicitMagnitudeF', 'vcRefAnnoFont',
+                'vcRefAnnoFontAspectF', 'vcRefAnnoFontColor', 'vcRefAnnoFontHeightF',
+                'vcRefAnnoFontQuality', 'vcRefAnnoFontThicknessF',
+                'vcRefAnnoFuncCode', 'vcRefAnnoJust', 'vcRefAnnoOn',
+                'vcRefAnnoOrientation', 'vcRefAnnoOrthogonalPosF',
+                'vcRefAnnoParallelPosF', 'vcRefAnnoPerimColor', 'vcRefAnnoPerimOn',
+                'vcRefAnnoPerimSpaceF', 'vcRefAnnoPerimThicknessF', 'vcRefAnnoSide',
+                'vcRefAnnoString1', 'vcRefAnnoString1On', 'vcRefAnnoString2',
+                'vcRefAnnoString2On', 'vcRefAnnoTextDirection', 'vcRefAnnoZone',
+                'vcRefLengthF', 'vcRefMagnitudeF', 'vcScalarFieldData',
+                'vcScalarMissingValColor', 'vcScalarValueFormat',
+                'vcScalarValueScaleFactorF', 'vcScalarValueScaleValueF',
+                'vcScalarValueScalingMode', 'vcSpanLevelPalette', 'vcUseRefAnnoRes',
+                'vcUseScalarArray', 'vcVectorDrawOrder', 'vcVectorFieldData',
+                'vcWindBarbCalmCircleSizeF', 'vcWindBarbColor',
+                'vcWindBarbLineThicknessF', 'vcWindBarbScaleFactorF',
+                'vcWindBarbTickAngleF', 'vcWindBarbTickLengthF',
+                'vcWindBarbTickSpacingF', 'vcZeroFLabelAngleF',
+                'vcZeroFLabelBackgroundColor', 'vcZeroFLabelConstantSpacingF',
+                'vcZeroFLabelFont', 'vcZeroFLabelFontAspectF',
+                'vcZeroFLabelFontColor', 'vcZeroFLabelFontHeightF',
+                'vcZeroFLabelFontQuality', 'vcZeroFLabelFontThicknessF',
+                'vcZeroFLabelFuncCode', 'vcZeroFLabelJust', 'vcZeroFLabelOn',
+                'vcZeroFLabelOrthogonalPosF', 'vcZeroFLabelParallelPosF',
+                'vcZeroFLabelPerimColor', 'vcZeroFLabelPerimOn',
+                'vcZeroFLabelPerimSpaceF', 'vcZeroFLabelPerimThicknessF',
+                'vcZeroFLabelSide', 'vcZeroFLabelString', 'vcZeroFLabelTextDirection',
+                'vcZeroFLabelZone', 'vfCopyData', 'vfDataArray',
+                'vfExchangeDimensions', 'vfExchangeUVData', 'vfMagMaxV', 'vfMagMinV',
+                'vfMissingUValueV', 'vfMissingVValueV', 'vfPolarData',
+                'vfSingleMissingValue', 'vfUDataArray', 'vfUMaxV', 'vfUMinV',
+                'vfVDataArray', 'vfVMaxV', 'vfVMinV', 'vfXArray', 'vfXCActualEndF',
+                'vfXCActualStartF', 'vfXCEndIndex', 'vfXCEndSubsetV', 'vfXCEndV',
+                'vfXCStartIndex', 'vfXCStartSubsetV', 'vfXCStartV', 'vfXCStride',
+                'vfYArray', 'vfYCActualEndF', 'vfYCActualStartF', 'vfYCEndIndex',
+                'vfYCEndSubsetV', 'vfYCEndV', 'vfYCStartIndex', 'vfYCStartSubsetV',
+                'vfYCStartV', 'vfYCStride', 'vpAnnoManagerId', 'vpClipOn',
+                'vpHeightF', 'vpKeepAspect', 'vpOn', 'vpUseSegments', 'vpWidthF',
+                'vpXF', 'vpYF', 'wkAntiAlias', 'wkBackgroundColor', 'wkBackgroundOpacityF',
+                'wkColorMapLen', 'wkColorMap', 'wkColorModel', 'wkDashTableLength',
+                'wkDefGraphicStyleId', 'wkDeviceLowerX', 'wkDeviceLowerY',
+                'wkDeviceUpperX', 'wkDeviceUpperY', 'wkFileName', 'wkFillTableLength',
+                'wkForegroundColor', 'wkFormat', 'wkFullBackground', 'wkGksWorkId',
+                'wkHeight', 'wkMarkerTableLength', 'wkMetaName', 'wkOrientation',
+                'wkPDFFileName', 'wkPDFFormat', 'wkPDFResolution', 'wkPSFileName',
+                'wkPSFormat', 'wkPSResolution', 'wkPaperHeightF', 'wkPaperSize',
+                'wkPaperWidthF', 'wkPause', 'wkTopLevelViews', 'wkViews',
+                'wkVisualType', 'wkWidth', 'wkWindowId', 'wkXColorMode', 'wsCurrentSize',
+                'wsMaximumSize', 'wsThresholdSize', 'xyComputeXMax',
+                'xyComputeXMin', 'xyComputeYMax', 'xyComputeYMin', 'xyCoordData',
+                'xyCoordDataSpec', 'xyCurveDrawOrder', 'xyDashPattern',
+                'xyDashPatterns', 'xyExplicitLabels', 'xyExplicitLegendLabels',
+                'xyLabelMode', 'xyLineColor', 'xyLineColors', 'xyLineDashSegLenF',
+                'xyLineLabelConstantSpacingF', 'xyLineLabelFont',
+                'xyLineLabelFontAspectF', 'xyLineLabelFontColor',
+                'xyLineLabelFontColors', 'xyLineLabelFontHeightF',
+                'xyLineLabelFontQuality', 'xyLineLabelFontThicknessF',
+                'xyLineLabelFuncCode', 'xyLineThicknessF', 'xyLineThicknesses',
+                'xyMarkLineMode', 'xyMarkLineModes', 'xyMarker', 'xyMarkerColor',
+                'xyMarkerColors', 'xyMarkerSizeF', 'xyMarkerSizes',
+                'xyMarkerThicknessF', 'xyMarkerThicknesses', 'xyMarkers',
+                'xyMonoDashPattern', 'xyMonoLineColor', 'xyMonoLineLabelFontColor',
+                'xyMonoLineThickness', 'xyMonoMarkLineMode', 'xyMonoMarker',
+                'xyMonoMarkerColor', 'xyMonoMarkerSize', 'xyMonoMarkerThickness',
+                'xyXIrrTensionF', 'xyXIrregularPoints', 'xyXStyle', 'xyYIrrTensionF',
+                'xyYIrregularPoints', 'xyYStyle'), prefix=r'\b'),
+             Name.Builtin),
+
+            # Booleans
+            (r'\.(True|False)\.', Name.Builtin),
+            # Comparing Operators
+            (r'\.(eq|ne|lt|le|gt|ge|not|and|or|xor)\.', Operator.Word),
+        ],
+
+        'strings': [
+            (r'(?s)"(\\\\|\\[0-7]+|\\.|[^"\\])*"', String.Double),
+        ],
+
+        'nums': [
+            (r'\d+(?![.e])(_[a-z]\w+)?', Number.Integer),
+            (r'[+-]?\d*\.\d+(e[-+]?\d+)?(_[a-z]\w+)?', Number.Float),
+            (r'[+-]?\d+\.\d*(e[-+]?\d+)?(_[a-z]\w+)?', Number.Float),
+        ],
+    }
diff --git a/pygments/lexers/nimrod.py b/pygments/lexers/nimrod.py
new file mode 100644
index 0000000000000000000000000000000000000000..365a8dcca961a8b1934b02d33cfaa181ed97736f
--- /dev/null
+++ b/pygments/lexers/nimrod.py
@@ -0,0 +1,199 @@
+"""
+    pygments.lexers.nimrod
+    ~~~~~~~~~~~~~~~~~~~~~~
+
+    Lexer for the Nim language (formerly known as Nimrod).
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import re
+
+from pygments.lexer import RegexLexer, include, default, bygroups
+from pygments.token import Text, Comment, Operator, Keyword, Name, String, \
+    Number, Punctuation, Error
+
+__all__ = ['NimrodLexer']
+
+
+class NimrodLexer(RegexLexer):
+    """
+    For Nim source code.
+    """
+
+    name = 'Nimrod'
+    url = 'http://nim-lang.org/'
+    aliases = ['nimrod', 'nim']
+    filenames = ['*.nim', '*.nimrod']
+    mimetypes = ['text/x-nim']
+    version_added = '1.5'
+
+    flags = re.MULTILINE | re.IGNORECASE
+
+    def underscorize(words):
+        newWords = []
+        new = []
+        for word in words:
+            for ch in word:
+                new.append(ch)
+                new.append("_?")
+            newWords.append(''.join(new))
+            new = []
+        return "|".join(newWords)
+
+    keywords = [
+        'addr', 'and', 'as', 'asm', 'bind', 'block', 'break', 'case',
+        'cast', 'concept', 'const', 'continue', 'converter', 'defer', 'discard',
+        'distinct', 'div', 'do', 'elif', 'else', 'end', 'enum', 'except',
+        'export', 'finally', 'for', 'if', 'in', 'yield', 'interface',
+        'is', 'isnot', 'iterator', 'let', 'mixin', 'mod',
+        'not', 'notin', 'object', 'of', 'or', 'out', 'ptr', 'raise',
+        'ref', 'return', 'shl', 'shr', 'static', 'try',
+        'tuple', 'type', 'using', 'when', 'while', 'xor'
+    ]
+
+    keywordsPseudo = [
+        'nil', 'true', 'false'
+    ]
+
+    opWords = [
+        'and', 'or', 'not', 'xor', 'shl', 'shr', 'div', 'mod', 'in',
+        'notin', 'is', 'isnot'
+    ]
+
+    types = [
+        'int', 'int8', 'int16', 'int32', 'int64', 'float', 'float32', 'float64',
+        'bool', 'char', 'range', 'array', 'seq', 'set', 'string'
+    ]
+
+    tokens = {
+        'root': [
+            # Comments
+            (r'##\[', String.Doc, 'doccomment'),
+            (r'##.*$', String.Doc),
+            (r'#\[', Comment.Multiline, 'comment'),
+            (r'#.*$', Comment),
+
+            # Pragmas
+            (r'\{\.', String.Other, 'pragma'),
+
+            # Operators
+            (r'[*=><+\-/@$~&%!?|\\\[\]]', Operator),
+            (r'\.\.|\.|,|\[\.|\.\]|\{\.|\.\}|\(\.|\.\)|\{|\}|\(|\)|:|\^|`|;',
+             Punctuation),
+
+            # Case statement branch
+            (r'(\n\s*)(of)(\s)', bygroups(Text.Whitespace, Keyword,
+                                          Text.Whitespace), 'casebranch'),
+
+            # Strings
+            (r'(?:[\w]+)"', String, 'rdqs'),
+            (r'"""', String.Double, 'tdqs'),
+            ('"', String, 'dqs'),
+
+            # Char
+            ("'", String.Char, 'chars'),
+
+            # Keywords
+            (rf'({underscorize(opWords)})\b', Operator.Word),
+            (r'(proc|func|method|macro|template)(\s)(?![(\[\]])',
+             bygroups(Keyword, Text.Whitespace), 'funcname'),
+            (rf'({underscorize(keywords)})\b', Keyword),
+            (r'({})\b'.format(underscorize(['from', 'import', 'include', 'export'])),
+             Keyword.Namespace),
+            (r'(v_?a_?r)\b', Keyword.Declaration),
+            (rf'({underscorize(types)})\b', Name.Builtin),
+            (rf'({underscorize(keywordsPseudo)})\b', Keyword.Pseudo),
+
+            # Identifiers
+            (r'\b((?![_\d])\w)(((?!_)\w)|(_(?!_)\w))*', Name),
+
+            # Numbers
+            (r'[0-9][0-9_]*(?=([e.]|\'f(32|64)))',
+             Number.Float, ('float-suffix', 'float-number')),
+            (r'0x[a-f0-9][a-f0-9_]*', Number.Hex, 'int-suffix'),
+            (r'0b[01][01_]*', Number.Bin, 'int-suffix'),
+            (r'0o[0-7][0-7_]*', Number.Oct, 'int-suffix'),
+            (r'[0-9][0-9_]*', Number.Integer, 'int-suffix'),
+
+            # Whitespace
+            (r'\s+', Text.Whitespace),
+            (r'.+$', Error),
+        ],
+        'chars': [
+            (r'\\([\\abcefnrtvl"\']|x[a-f0-9]{2}|[0-9]{1,3})', String.Escape),
+            (r"'", String.Char, '#pop'),
+            (r".", String.Char)
+        ],
+        'strings': [
+            (r'(?|>=|>>|>|<=|<<|<|\+|-|=|/|\*|%|\+=|-=|!|@', Operator),
+            (r'\(|\)|\[|\]|,|\.\.\.|\.\.|\.|::|:', Punctuation),
+            (r'`\{[^`]*`\}', Text),  # Extern blocks won't be Lexed by Nit
+            (r'[\r\n\t ]+', Text),
+        ],
+    }
diff --git a/pygments/lexers/nix.py b/pygments/lexers/nix.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fa88c65aafd4843fe77b5b6774060339cb4c821
--- /dev/null
+++ b/pygments/lexers/nix.py
@@ -0,0 +1,144 @@
+"""
+    pygments.lexers.nix
+    ~~~~~~~~~~~~~~~~~~~
+
+    Lexers for the NixOS Nix language.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import re
+
+from pygments.lexer import RegexLexer, include
+from pygments.token import Text, Comment, Operator, Keyword, Name, String, \
+    Number, Punctuation, Literal
+
+__all__ = ['NixLexer']
+
+
+class NixLexer(RegexLexer):
+    """
+    For the Nix language.
+    """
+
+    name = 'Nix'
+    url = 'http://nixos.org/nix/'
+    aliases = ['nixos', 'nix']
+    filenames = ['*.nix']
+    mimetypes = ['text/x-nix']
+    version_added = '2.0'
+
+    keywords = ['rec', 'with', 'let', 'in', 'inherit', 'assert', 'if',
+                'else', 'then', '...']
+    builtins = ['import', 'abort', 'baseNameOf', 'dirOf', 'isNull', 'builtins',
+                'map', 'removeAttrs', 'throw', 'toString', 'derivation']
+    operators = ['++', '+', '?', '.', '!', '//', '==', '/',
+                 '!=', '&&', '||', '->', '=', '<', '>', '*', '-']
+
+    punctuations = ["(", ")", "[", "]", ";", "{", "}", ":", ",", "@"]
+
+    tokens = {
+        'root': [
+            # comments starting with #
+            (r'#.*$', Comment.Single),
+
+            # multiline comments
+            (r'/\*', Comment.Multiline, 'comment'),
+
+            # whitespace
+            (r'\s+', Text),
+
+            # keywords
+            ('({})'.format('|'.join(re.escape(entry) + '\\b' for entry in keywords)), Keyword),
+
+            # highlight the builtins
+            ('({})'.format('|'.join(re.escape(entry) + '\\b' for entry in builtins)),
+             Name.Builtin),
+
+            (r'\b(true|false|null)\b', Name.Constant),
+
+            # floats
+            (r'-?(\d+\.\d*|\.\d+)([eE][-+]?\d+)?', Number.Float),
+
+            # integers
+            (r'-?[0-9]+', Number.Integer),
+
+            # paths
+            (r'[\w.+-]*(\/[\w.+-]+)+', Literal),
+            (r'~(\/[\w.+-]+)+', Literal),
+            (r'\<[\w.+-]+(\/[\w.+-]+)*\>', Literal),
+
+            # operators
+            ('({})'.format('|'.join(re.escape(entry) for entry in operators)),
+             Operator),
+
+            # word operators
+            (r'\b(or|and)\b', Operator.Word),
+
+            (r'\{', Punctuation, 'block'),
+
+            # punctuations
+            ('({})'.format('|'.join(re.escape(entry) for entry in punctuations)), Punctuation),
+
+            # strings
+            (r'"', String.Double, 'doublequote'),
+            (r"''", String.Multiline, 'multiline'),
+
+            # urls
+            (r'[a-zA-Z][a-zA-Z0-9\+\-\.]*\:[\w%/?:@&=+$,\\.!~*\'-]+', Literal),
+
+            # names of variables
+            (r'[\w-]+(?=\s*=)', String.Symbol),
+            (r'[a-zA-Z_][\w\'-]*', Text),
+
+            (r"\$\{", String.Interpol, 'antiquote'),
+        ],
+        'comment': [
+            (r'[^/*]+', Comment.Multiline),
+            (r'/\*', Comment.Multiline, '#push'),
+            (r'\*/', Comment.Multiline, '#pop'),
+            (r'[*/]', Comment.Multiline),
+        ],
+        'multiline': [
+            (r"''(\$|'|\\n|\\r|\\t|\\)", String.Escape),
+            (r"''", String.Multiline, '#pop'),
+            (r'\$\{', String.Interpol, 'antiquote'),
+            (r"[^'\$]+", String.Multiline),
+            (r"\$[^\{']", String.Multiline),
+            (r"'[^']", String.Multiline),
+            (r"\$(?=')", String.Multiline),
+        ],
+        'doublequote': [
+            (r'\\(\\|"|\$|n)', String.Escape),
+            (r'"', String.Double, '#pop'),
+            (r'\$\{', String.Interpol, 'antiquote'),
+            (r'[^"\\\$]+', String.Double),
+            (r'\$[^\{"]', String.Double),
+            (r'\$(?=")', String.Double),
+            (r'\\', String.Double),
+        ],
+        'antiquote': [
+            (r"\}", String.Interpol, '#pop'),
+            # TODO: we should probably escape also here ''${ \${
+            (r"\$\{", String.Interpol, '#push'),
+            include('root'),
+        ],
+        'block': [
+            (r"\}", Punctuation, '#pop'),
+            include('root'),
+        ],
+    }
+
+    def analyse_text(text):
+        rv = 0.0
+        # TODO: let/in
+        if re.search(r'import.+?<[^>]+>', text):
+            rv += 0.4
+        if re.search(r'mkDerivation\s+(\(|\{|rec)', text):
+            rv += 0.4
+        if re.search(r'=\s+mkIf\s+', text):
+            rv += 0.4
+        if re.search(r'\{[a-zA-Z,\s]+\}:', text):
+            rv += 0.1
+        return rv
diff --git a/pygments/lexers/numbair.py b/pygments/lexers/numbair.py
new file mode 100644
index 0000000000000000000000000000000000000000..435863e132afcff16c70ed81bb51739848c1ecc0
--- /dev/null
+++ b/pygments/lexers/numbair.py
@@ -0,0 +1,63 @@
+"""
+    pygments.lexers.numbair
+    ~~~~~~~~~~~~~~~~~~~~~~~
+
+    Lexer for other Numba Intermediate Representation.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, include, bygroups, words
+from pygments.token import Whitespace, Name, String,  Punctuation, Keyword, \
+    Operator, Number
+
+__all__ = ["NumbaIRLexer"]
+
+class NumbaIRLexer(RegexLexer):
+    """
+    Lexer for Numba IR
+    """
+    name = 'Numba_IR'
+    url = "https://numba.readthedocs.io/en/stable/developer/architecture.html#stage-2-generate-the-numba-ir"
+    aliases = ['numba_ir', 'numbair']
+    filenames = ['*.numba_ir']
+    mimetypes = ['text/x-numba_ir', 'text/x-numbair']
+    version_added = '2.19'
+
+    identifier = r'\$[a-zA-Z0-9._]+'
+    fun_or_var = r'([a-zA-Z_]+[a-zA-Z0-9]*)'
+
+    tokens = {
+        'root' : [
+            (r'(label)(\ [0-9]+)(:)$',
+                bygroups(Keyword, Name.Label, Punctuation)),
+
+            (r'=', Operator),
+            include('whitespace'),
+            include('keyword'),
+
+            (identifier, Name.Variable),
+            (fun_or_var + r'(\()',
+                bygroups(Name.Function, Punctuation)),
+            (fun_or_var + r'(\=)',
+                bygroups(Name.Attribute, Punctuation)),
+            (fun_or_var, Name.Constant),
+            (r'[0-9]+', Number),
+
+            # 
+            (r'<[^>\n]*>', String),
+
+            (r'[=<>{}\[\]()*.,!\':]|x\b', Punctuation)
+        ],
+
+        'keyword':[
+            (words((
+                'del', 'jump', 'call', 'branch',
+            ), suffix=' '), Keyword),
+        ],
+
+        'whitespace': [
+            (r'(\n|\s)+', Whitespace),
+        ],
+    }
diff --git a/pygments/lexers/oberon.py b/pygments/lexers/oberon.py
new file mode 100644
index 0000000000000000000000000000000000000000..61f3c2d2879e4dc7fedade39e203c1c50d5bbd20
--- /dev/null
+++ b/pygments/lexers/oberon.py
@@ -0,0 +1,120 @@
+"""
+    pygments.lexers.oberon
+    ~~~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for Oberon family languages.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import re
+
+from pygments.lexer import RegexLexer, include, words
+from pygments.token import Text, Comment, Operator, Keyword, Name, String, \
+    Number, Punctuation
+
+__all__ = ['ComponentPascalLexer']
+
+
+class ComponentPascalLexer(RegexLexer):
+    """
+    For Component Pascal source code.
+    """
+    name = 'Component Pascal'
+    aliases = ['componentpascal', 'cp']
+    filenames = ['*.cp', '*.cps']
+    mimetypes = ['text/x-component-pascal']
+    url = 'https://blackboxframework.org'
+    version_added = '2.1'
+
+    flags = re.MULTILINE | re.DOTALL
+
+    tokens = {
+        'root': [
+            include('whitespace'),
+            include('comments'),
+            include('punctuation'),
+            include('numliterals'),
+            include('strings'),
+            include('operators'),
+            include('builtins'),
+            include('identifiers'),
+        ],
+        'whitespace': [
+            (r'\n+', Text),  # blank lines
+            (r'\s+', Text),  # whitespace
+        ],
+        'comments': [
+            (r'\(\*([^$].*?)\*\)', Comment.Multiline),
+            # TODO: nested comments (* (* ... *) ... (* ... *) *) not supported!
+        ],
+        'punctuation': [
+            (r'[()\[\]{},.:;|]', Punctuation),
+        ],
+        'numliterals': [
+            (r'[0-9A-F]+X\b', Number.Hex),                 # char code
+            (r'[0-9A-F]+[HL]\b', Number.Hex),              # hexadecimal number
+            (r'[0-9]+\.[0-9]+E[+-][0-9]+', Number.Float),  # real number
+            (r'[0-9]+\.[0-9]+', Number.Float),             # real number
+            (r'[0-9]+', Number.Integer),                   # decimal whole number
+        ],
+        'strings': [
+            (r"'[^\n']*'", String),  # single quoted string
+            (r'"[^\n"]*"', String),  # double quoted string
+        ],
+        'operators': [
+            # Arithmetic Operators
+            (r'[+-]', Operator),
+            (r'[*/]', Operator),
+            # Relational Operators
+            (r'[=#<>]', Operator),
+            # Dereferencing Operator
+            (r'\^', Operator),
+            # Logical AND Operator
+            (r'&', Operator),
+            # Logical NOT Operator
+            (r'~', Operator),
+            # Assignment Symbol
+            (r':=', Operator),
+            # Range Constructor
+            (r'\.\.', Operator),
+            (r'\$', Operator),
+        ],
+        'identifiers': [
+            (r'([a-zA-Z_$][\w$]*)', Name),
+        ],
+        'builtins': [
+            (words((
+                'ANYPTR', 'ANYREC', 'BOOLEAN', 'BYTE', 'CHAR', 'INTEGER', 'LONGINT',
+                'REAL', 'SET', 'SHORTCHAR', 'SHORTINT', 'SHORTREAL'
+                ), suffix=r'\b'), Keyword.Type),
+            (words((
+                'ABS', 'ABSTRACT', 'ARRAY', 'ASH', 'ASSERT', 'BEGIN', 'BITS', 'BY',
+                'CAP', 'CASE', 'CHR', 'CLOSE', 'CONST', 'DEC', 'DIV', 'DO', 'ELSE',
+                'ELSIF', 'EMPTY', 'END', 'ENTIER', 'EXCL', 'EXIT', 'EXTENSIBLE', 'FOR',
+                'HALT', 'IF', 'IMPORT', 'IN', 'INC', 'INCL', 'IS', 'LEN', 'LIMITED',
+                'LONG', 'LOOP', 'MAX', 'MIN', 'MOD', 'MODULE', 'NEW', 'ODD', 'OF',
+                'OR', 'ORD', 'OUT', 'POINTER', 'PROCEDURE', 'RECORD', 'REPEAT', 'RETURN',
+                'SHORT', 'SHORTCHAR', 'SHORTINT', 'SIZE', 'THEN', 'TYPE', 'TO', 'UNTIL',
+                'VAR', 'WHILE', 'WITH'
+                ), suffix=r'\b'), Keyword.Reserved),
+            (r'(TRUE|FALSE|NIL|INF)\b', Keyword.Constant),
+        ]
+    }
+
+    def analyse_text(text):
+        """The only other lexer using .cp is the C++ one, so we check if for
+        a few common Pascal keywords here. Those are unfortunately quite
+        common across various business languages as well."""
+        result = 0
+        if 'BEGIN' in text:
+            result += 0.01
+        if 'END' in text:
+            result += 0.01
+        if 'PROCEDURE' in text:
+            result += 0.01
+        if 'END' in text:
+            result += 0.01
+
+        return result
diff --git a/pygments/lexers/objective.py b/pygments/lexers/objective.py
new file mode 100644
index 0000000000000000000000000000000000000000..899c2c44c965248e98a705d9bb330d4e8158b9ac
--- /dev/null
+++ b/pygments/lexers/objective.py
@@ -0,0 +1,513 @@
+"""
+    pygments.lexers.objective
+    ~~~~~~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for Objective-C family languages.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import re
+
+from pygments.lexer import RegexLexer, include, bygroups, using, this, words, \
+    inherit, default
+from pygments.token import Text, Keyword, Name, String, Operator, \
+    Number, Punctuation, Literal, Comment, Whitespace
+
+from pygments.lexers.c_cpp import CLexer, CppLexer
+
+__all__ = ['ObjectiveCLexer', 'ObjectiveCppLexer', 'LogosLexer', 'SwiftLexer']
+
+
+def objective(baselexer):
+    """
+    Generate a subclass of baselexer that accepts the Objective-C syntax
+    extensions.
+    """
+
+    # Have to be careful not to accidentally match JavaDoc/Doxygen syntax here,
+    # since that's quite common in ordinary C/C++ files.  It's OK to match
+    # JavaDoc/Doxygen keywords that only apply to Objective-C, mind.
+    #
+    # The upshot of this is that we CANNOT match @class or @interface
+    _oc_keywords = re.compile(r'@(?:end|implementation|protocol)')
+
+    # Matches [ ? identifier  ( identifier ? ] |  identifier? : )
+    # (note the identifier is *optional* when there is a ':'!)
+    _oc_message = re.compile(r'\[\s*[a-zA-Z_]\w*\s+'
+                             r'(?:[a-zA-Z_]\w*\s*\]|'
+                             r'(?:[a-zA-Z_]\w*)?:)')
+
+    class GeneratedObjectiveCVariant(baselexer):
+        """
+        Implements Objective-C syntax on top of an existing C family lexer.
+        """
+
+        tokens = {
+            'statements': [
+                (r'@"', String, 'string'),
+                (r'@(YES|NO)', Number),
+                (r"@'(\\.|\\[0-7]{1,3}|\\x[a-fA-F0-9]{1,2}|[^\\\'\n])'", String.Char),
+                (r'@(\d+\.\d*|\.\d+|\d+)[eE][+-]?\d+[lL]?', Number.Float),
+                (r'@(\d+\.\d*|\.\d+|\d+[fF])[fF]?', Number.Float),
+                (r'@0x[0-9a-fA-F]+[Ll]?', Number.Hex),
+                (r'@0[0-7]+[Ll]?', Number.Oct),
+                (r'@\d+[Ll]?', Number.Integer),
+                (r'@\(', Literal, 'literal_number'),
+                (r'@\[', Literal, 'literal_array'),
+                (r'@\{', Literal, 'literal_dictionary'),
+                (words((
+                    '@selector', '@private', '@protected', '@public', '@encode',
+                    '@synchronized', '@try', '@throw', '@catch', '@finally',
+                    '@end', '@property', '@synthesize', '__bridge', '__bridge_transfer',
+                    '__autoreleasing', '__block', '__weak', '__strong', 'weak', 'strong',
+                    'copy', 'retain', 'assign', 'unsafe_unretained', 'atomic', 'nonatomic',
+                    'readonly', 'readwrite', 'setter', 'getter', 'typeof', 'in',
+                    'out', 'inout', 'release', 'class', '@dynamic', '@optional',
+                    '@required', '@autoreleasepool', '@import'), suffix=r'\b'),
+                 Keyword),
+                (words(('id', 'instancetype', 'Class', 'IMP', 'SEL', 'BOOL',
+                        'IBOutlet', 'IBAction', 'unichar'), suffix=r'\b'),
+                 Keyword.Type),
+                (r'@(true|false|YES|NO)\n', Name.Builtin),
+                (r'(YES|NO|nil|self|super)\b', Name.Builtin),
+                # Carbon types
+                (r'(Boolean|UInt8|SInt8|UInt16|SInt16|UInt32|SInt32)\b', Keyword.Type),
+                # Carbon built-ins
+                (r'(TRUE|FALSE)\b', Name.Builtin),
+                (r'(@interface|@implementation)(\s+)', bygroups(Keyword, Text),
+                 ('#pop', 'oc_classname')),
+                (r'(@class|@protocol)(\s+)', bygroups(Keyword, Text),
+                 ('#pop', 'oc_forward_classname')),
+                # @ can also prefix other expressions like @{...} or @(...)
+                (r'@', Punctuation),
+                inherit,
+            ],
+            'oc_classname': [
+                # interface definition that inherits
+                (r'([a-zA-Z$_][\w$]*)(\s*:\s*)([a-zA-Z$_][\w$]*)?(\s*)(\{)',
+                 bygroups(Name.Class, Text, Name.Class, Text, Punctuation),
+                 ('#pop', 'oc_ivars')),
+                (r'([a-zA-Z$_][\w$]*)(\s*:\s*)([a-zA-Z$_][\w$]*)?',
+                 bygroups(Name.Class, Text, Name.Class), '#pop'),
+                # interface definition for a category
+                (r'([a-zA-Z$_][\w$]*)(\s*)(\([a-zA-Z$_][\w$]*\))(\s*)(\{)',
+                 bygroups(Name.Class, Text, Name.Label, Text, Punctuation),
+                 ('#pop', 'oc_ivars')),
+                (r'([a-zA-Z$_][\w$]*)(\s*)(\([a-zA-Z$_][\w$]*\))',
+                 bygroups(Name.Class, Text, Name.Label), '#pop'),
+                # simple interface / implementation
+                (r'([a-zA-Z$_][\w$]*)(\s*)(\{)',
+                 bygroups(Name.Class, Text, Punctuation), ('#pop', 'oc_ivars')),
+                (r'([a-zA-Z$_][\w$]*)', Name.Class, '#pop')
+            ],
+            'oc_forward_classname': [
+                (r'([a-zA-Z$_][\w$]*)(\s*,\s*)',
+                 bygroups(Name.Class, Text), 'oc_forward_classname'),
+                (r'([a-zA-Z$_][\w$]*)(\s*;?)',
+                 bygroups(Name.Class, Text), '#pop')
+            ],
+            'oc_ivars': [
+                include('whitespace'),
+                include('statements'),
+                (';', Punctuation),
+                (r'\{', Punctuation, '#push'),
+                (r'\}', Punctuation, '#pop'),
+            ],
+            'root': [
+                # methods
+                (r'^([-+])(\s*)'                         # method marker
+                 r'(\(.*?\))?(\s*)'                      # return type
+                 r'([a-zA-Z$_][\w$]*:?)',        # begin of method name
+                 bygroups(Punctuation, Text, using(this),
+                          Text, Name.Function),
+                 'method'),
+                inherit,
+            ],
+            'method': [
+                include('whitespace'),
+                # TODO unsure if ellipses are allowed elsewhere, see
+                # discussion in Issue 789
+                (r',', Punctuation),
+                (r'\.\.\.', Punctuation),
+                (r'(\(.*?\))(\s*)([a-zA-Z$_][\w$]*)',
+                 bygroups(using(this), Text, Name.Variable)),
+                (r'[a-zA-Z$_][\w$]*:', Name.Function),
+                (';', Punctuation, '#pop'),
+                (r'\{', Punctuation, 'function'),
+                default('#pop'),
+            ],
+            'literal_number': [
+                (r'\(', Punctuation, 'literal_number_inner'),
+                (r'\)', Literal, '#pop'),
+                include('statement'),
+            ],
+            'literal_number_inner': [
+                (r'\(', Punctuation, '#push'),
+                (r'\)', Punctuation, '#pop'),
+                include('statement'),
+            ],
+            'literal_array': [
+                (r'\[', Punctuation, 'literal_array_inner'),
+                (r'\]', Literal, '#pop'),
+                include('statement'),
+            ],
+            'literal_array_inner': [
+                (r'\[', Punctuation, '#push'),
+                (r'\]', Punctuation, '#pop'),
+                include('statement'),
+            ],
+            'literal_dictionary': [
+                (r'\}', Literal, '#pop'),
+                include('statement'),
+            ],
+        }
+
+        def analyse_text(text):
+            if _oc_keywords.search(text):
+                return 1.0
+            elif '@"' in text:  # strings
+                return 0.8
+            elif re.search('@[0-9]+', text):
+                return 0.7
+            elif _oc_message.search(text):
+                return 0.8
+            return 0
+
+        def get_tokens_unprocessed(self, text, stack=('root',)):
+            from pygments.lexers._cocoa_builtins import COCOA_INTERFACES, \
+                COCOA_PROTOCOLS, COCOA_PRIMITIVES
+
+            for index, token, value in \
+                    baselexer.get_tokens_unprocessed(self, text, stack):
+                if token is Name or token is Name.Class:
+                    if value in COCOA_INTERFACES or value in COCOA_PROTOCOLS \
+                       or value in COCOA_PRIMITIVES:
+                        token = Name.Builtin.Pseudo
+
+                yield index, token, value
+
+    return GeneratedObjectiveCVariant
+
+
+class ObjectiveCLexer(objective(CLexer)):
+    """
+    For Objective-C source code with preprocessor directives.
+    """
+
+    name = 'Objective-C'
+    url = 'https://developer.apple.com/library/archive/documentation/Cocoa/Conceptual/ProgrammingWithObjectiveC/Introduction/Introduction.html'
+    aliases = ['objective-c', 'objectivec', 'obj-c', 'objc']
+    filenames = ['*.m', '*.h']
+    mimetypes = ['text/x-objective-c']
+    version_added = ''
+    priority = 0.05    # Lower than C
+
+
+class ObjectiveCppLexer(objective(CppLexer)):
+    """
+    For Objective-C++ source code with preprocessor directives.
+    """
+
+    name = 'Objective-C++'
+    aliases = ['objective-c++', 'objectivec++', 'obj-c++', 'objc++']
+    filenames = ['*.mm', '*.hh']
+    mimetypes = ['text/x-objective-c++']
+    version_added = ''
+    priority = 0.05    # Lower than C++
+
+
+class LogosLexer(ObjectiveCppLexer):
+    """
+    For Logos + Objective-C source code with preprocessor directives.
+    """
+
+    name = 'Logos'
+    aliases = ['logos']
+    filenames = ['*.x', '*.xi', '*.xm', '*.xmi']
+    mimetypes = ['text/x-logos']
+    version_added = '1.6'
+    priority = 0.25
+
+    tokens = {
+        'statements': [
+            (r'(%orig|%log)\b', Keyword),
+            (r'(%c)\b(\()(\s*)([a-zA-Z$_][\w$]*)(\s*)(\))',
+             bygroups(Keyword, Punctuation, Text, Name.Class, Text, Punctuation)),
+            (r'(%init)\b(\()',
+             bygroups(Keyword, Punctuation), 'logos_init_directive'),
+            (r'(%init)(?=\s*;)', bygroups(Keyword)),
+            (r'(%hook|%group)(\s+)([a-zA-Z$_][\w$]+)',
+             bygroups(Keyword, Text, Name.Class), '#pop'),
+            (r'(%subclass)(\s+)', bygroups(Keyword, Text),
+             ('#pop', 'logos_classname')),
+            inherit,
+        ],
+        'logos_init_directive': [
+            (r'\s+', Text),
+            (',', Punctuation, ('logos_init_directive', '#pop')),
+            (r'([a-zA-Z$_][\w$]*)(\s*)(=)(\s*)([^);]*)',
+             bygroups(Name.Class, Text, Punctuation, Text, Text)),
+            (r'([a-zA-Z$_][\w$]*)', Name.Class),
+            (r'\)', Punctuation, '#pop'),
+        ],
+        'logos_classname': [
+            (r'([a-zA-Z$_][\w$]*)(\s*:\s*)([a-zA-Z$_][\w$]*)?',
+             bygroups(Name.Class, Text, Name.Class), '#pop'),
+            (r'([a-zA-Z$_][\w$]*)', Name.Class, '#pop')
+        ],
+        'root': [
+            (r'(%subclass)(\s+)', bygroups(Keyword, Text),
+             'logos_classname'),
+            (r'(%hook|%group)(\s+)([a-zA-Z$_][\w$]+)',
+             bygroups(Keyword, Text, Name.Class)),
+            (r'(%config)(\s*\(\s*)(\w+)(\s*=)(.*?)(\)\s*)',
+             bygroups(Keyword, Text, Name.Variable, Text, String, Text)),
+            (r'(%ctor)(\s*)(\{)', bygroups(Keyword, Text, Punctuation),
+             'function'),
+            (r'(%new)(\s*)(\()(.*?)(\))',
+             bygroups(Keyword, Text, Keyword, String, Keyword)),
+            (r'(\s*)(%end)(\s*)', bygroups(Text, Keyword, Text)),
+            inherit,
+        ],
+    }
+
+    _logos_keywords = re.compile(r'%(?:hook|ctor|init|c\()')
+
+    def analyse_text(text):
+        if LogosLexer._logos_keywords.search(text):
+            return 1.0
+        return 0
+
+
+class SwiftLexer(RegexLexer):
+    """
+    For Swift source.
+    """
+    name = 'Swift'
+    url = 'https://www.swift.org/'
+    filenames = ['*.swift']
+    aliases = ['swift']
+    mimetypes = ['text/x-swift']
+    version_added = '2.0'
+
+    tokens = {
+        'root': [
+            # Whitespace and Comments
+            (r'\n', Text),
+            (r'\s+', Whitespace),
+            (r'//', Comment.Single, 'comment-single'),
+            (r'/\*', Comment.Multiline, 'comment-multi'),
+            (r'#(if|elseif|else|endif|available)\b', Comment.Preproc, 'preproc'),
+
+            # Keywords
+            include('keywords'),
+
+            # Global Types
+            (words((
+                'Array', 'AutoreleasingUnsafeMutablePointer', 'BidirectionalReverseView',
+                'Bit', 'Bool', 'CFunctionPointer', 'COpaquePointer', 'CVaListPointer',
+                'Character', 'ClosedInterval', 'CollectionOfOne', 'ContiguousArray',
+                'Dictionary', 'DictionaryGenerator', 'DictionaryIndex', 'Double',
+                'EmptyCollection', 'EmptyGenerator', 'EnumerateGenerator',
+                'EnumerateSequence', 'FilterCollectionView',
+                'FilterCollectionViewIndex', 'FilterGenerator', 'FilterSequenceView',
+                'Float', 'Float80', 'FloatingPointClassification', 'GeneratorOf',
+                'GeneratorOfOne', 'GeneratorSequence', 'HalfOpenInterval', 'HeapBuffer',
+                'HeapBufferStorage', 'ImplicitlyUnwrappedOptional', 'IndexingGenerator',
+                'Int', 'Int16', 'Int32', 'Int64', 'Int8', 'LazyBidirectionalCollection',
+                'LazyForwardCollection', 'LazyRandomAccessCollection',
+                'LazySequence', 'MapCollectionView', 'MapSequenceGenerator',
+                'MapSequenceView', 'MirrorDisposition', 'ObjectIdentifier', 'OnHeap',
+                'Optional', 'PermutationGenerator', 'QuickLookObject',
+                'RandomAccessReverseView', 'Range', 'RangeGenerator', 'RawByte', 'Repeat',
+                'ReverseBidirectionalIndex', 'ReverseRandomAccessIndex', 'SequenceOf',
+                'SinkOf', 'Slice', 'StaticString', 'StrideThrough', 'StrideThroughGenerator',
+                'StrideTo', 'StrideToGenerator', 'String', 'UInt', 'UInt16', 'UInt32',
+                'UInt64', 'UInt8', 'UTF16', 'UTF32', 'UTF8', 'UnicodeDecodingResult',
+                'UnicodeScalar', 'Unmanaged', 'UnsafeBufferPointer',
+                'UnsafeBufferPointerGenerator', 'UnsafeMutableBufferPointer',
+                'UnsafeMutablePointer', 'UnsafePointer', 'Zip2', 'ZipGenerator2',
+                # Protocols
+                'AbsoluteValuable', 'AnyObject', 'ArrayLiteralConvertible',
+                'BidirectionalIndexType', 'BitwiseOperationsType',
+                'BooleanLiteralConvertible', 'BooleanType', 'CVarArgType',
+                'CollectionType', 'Comparable', 'DebugPrintable',
+                'DictionaryLiteralConvertible', 'Equatable',
+                'ExtendedGraphemeClusterLiteralConvertible',
+                'ExtensibleCollectionType', 'FloatLiteralConvertible',
+                'FloatingPointType', 'ForwardIndexType', 'GeneratorType', 'Hashable',
+                'IntegerArithmeticType', 'IntegerLiteralConvertible', 'IntegerType',
+                'IntervalType', 'MirrorType', 'MutableCollectionType', 'MutableSliceable',
+                'NilLiteralConvertible', 'OutputStreamType', 'Printable',
+                'RandomAccessIndexType', 'RangeReplaceableCollectionType',
+                'RawOptionSetType', 'RawRepresentable', 'Reflectable', 'SequenceType',
+                'SignedIntegerType', 'SignedNumberType', 'SinkType', 'Sliceable',
+                'Streamable', 'Strideable', 'StringInterpolationConvertible',
+                'StringLiteralConvertible', 'UnicodeCodecType',
+                'UnicodeScalarLiteralConvertible', 'UnsignedIntegerType',
+                '_ArrayBufferType', '_BidirectionalIndexType', '_CocoaStringType',
+                '_CollectionType', '_Comparable', '_ExtensibleCollectionType',
+                '_ForwardIndexType', '_Incrementable', '_IntegerArithmeticType',
+                '_IntegerType', '_ObjectiveCBridgeable', '_RandomAccessIndexType',
+                '_RawOptionSetType', '_SequenceType', '_Sequence_Type',
+                '_SignedIntegerType', '_SignedNumberType', '_Sliceable', '_Strideable',
+                '_SwiftNSArrayRequiredOverridesType', '_SwiftNSArrayType',
+                '_SwiftNSCopyingType', '_SwiftNSDictionaryRequiredOverridesType',
+                '_SwiftNSDictionaryType', '_SwiftNSEnumeratorType',
+                '_SwiftNSFastEnumerationType', '_SwiftNSStringRequiredOverridesType',
+                '_SwiftNSStringType', '_UnsignedIntegerType',
+                # Variables
+                'C_ARGC', 'C_ARGV', 'Process',
+                # Typealiases
+                'Any', 'AnyClass', 'BooleanLiteralType', 'CBool', 'CChar', 'CChar16',
+                'CChar32', 'CDouble', 'CFloat', 'CInt', 'CLong', 'CLongLong', 'CShort',
+                'CSignedChar', 'CUnsignedInt', 'CUnsignedLong', 'CUnsignedShort',
+                'CWideChar', 'ExtendedGraphemeClusterType', 'Float32', 'Float64',
+                'FloatLiteralType', 'IntMax', 'IntegerLiteralType', 'StringLiteralType',
+                'UIntMax', 'UWord', 'UnicodeScalarType', 'Void', 'Word',
+                # Foundation/Cocoa
+                'NSErrorPointer', 'NSObjectProtocol', 'Selector'), suffix=r'\b'),
+             Name.Builtin),
+            # Functions
+            (words((
+                'abs', 'advance', 'alignof', 'alignofValue', 'assert', 'assertionFailure',
+                'contains', 'count', 'countElements', 'debugPrint', 'debugPrintln',
+                'distance', 'dropFirst', 'dropLast', 'dump', 'enumerate', 'equal',
+                'extend', 'fatalError', 'filter', 'find', 'first', 'getVaList', 'indices',
+                'insert', 'isEmpty', 'join', 'last', 'lazy', 'lexicographicalCompare',
+                'map', 'max', 'maxElement', 'min', 'minElement', 'numericCast', 'overlaps',
+                'partition', 'precondition', 'preconditionFailure', 'prefix', 'print',
+                'println', 'reduce', 'reflect', 'removeAll', 'removeAtIndex', 'removeLast',
+                'removeRange', 'reverse', 'sizeof', 'sizeofValue', 'sort', 'sorted',
+                'splice', 'split', 'startsWith', 'stride', 'strideof', 'strideofValue',
+                'suffix', 'swap', 'toDebugString', 'toString', 'transcode',
+                'underestimateCount', 'unsafeAddressOf', 'unsafeBitCast', 'unsafeDowncast',
+                'withExtendedLifetime', 'withUnsafeMutablePointer',
+                'withUnsafeMutablePointers', 'withUnsafePointer', 'withUnsafePointers',
+                'withVaList'), suffix=r'\b'),
+             Name.Builtin.Pseudo),
+
+            # Implicit Block Variables
+            (r'\$\d+', Name.Variable),
+
+            # Binary Literal
+            (r'0b[01_]+', Number.Bin),
+            # Octal Literal
+            (r'0o[0-7_]+', Number.Oct),
+            # Hexadecimal Literal
+            (r'0x[0-9a-fA-F_]+', Number.Hex),
+            # Decimal Literal
+            (r'[0-9][0-9_]*(\.[0-9_]+[eE][+\-]?[0-9_]+|'
+             r'\.[0-9_]*|[eE][+\-]?[0-9_]+)', Number.Float),
+            (r'[0-9][0-9_]*', Number.Integer),
+            # String Literal
+            (r'"""', String, 'string-multi'),
+            (r'"', String, 'string'),
+
+            # Operators and Punctuation
+            (r'[(){}\[\].,:;=@#`?]|->|[<&?](?=\w)|(?<=\w)[>!?]', Punctuation),
+            (r'[/=\-+!*%<>&|^?~]+', Operator),
+
+            # Identifier
+            (r'[a-zA-Z_]\w*', Name)
+        ],
+        'keywords': [
+            (words((
+                'as', 'async', 'await', 'break', 'case', 'catch', 'continue', 'default', 'defer',
+                'do', 'else', 'fallthrough', 'for', 'guard', 'if', 'in', 'is',
+                'repeat', 'return', '#selector', 'switch', 'throw', 'try',
+                'where', 'while'), suffix=r'\b'),
+             Keyword),
+            (r'@availability\([^)]+\)', Keyword.Reserved),
+            (words((
+                'associativity', 'convenience', 'dynamic', 'didSet', 'final',
+                'get', 'indirect', 'infix', 'inout', 'lazy', 'left', 'mutating',
+                'none', 'nonmutating', 'optional', 'override', 'postfix',
+                'precedence', 'prefix', 'Protocol', 'required', 'rethrows',
+                'right', 'set', 'throws', 'Type', 'unowned', 'weak', 'willSet',
+                '@availability', '@autoclosure', '@noreturn',
+                '@NSApplicationMain', '@NSCopying', '@NSManaged', '@objc',
+                '@UIApplicationMain', '@IBAction', '@IBDesignable',
+                '@IBInspectable', '@IBOutlet'), suffix=r'\b'),
+             Keyword.Reserved),
+            (r'(as|dynamicType|false|is|nil|self|Self|super|true|__COLUMN__'
+             r'|__FILE__|__FUNCTION__|__LINE__|_'
+             r'|#(?:file|line|column|function))\b', Keyword.Constant),
+            (r'import\b', Keyword.Declaration, 'module'),
+            (r'(class|enum|extension|struct|protocol)(\s+)([a-zA-Z_]\w*)',
+             bygroups(Keyword.Declaration, Whitespace, Name.Class)),
+            (r'(func)(\s+)([a-zA-Z_]\w*)',
+             bygroups(Keyword.Declaration, Whitespace, Name.Function)),
+            (r'(var|let)(\s+)([a-zA-Z_]\w*)', bygroups(Keyword.Declaration,
+             Whitespace, Name.Variable)),
+            (words((
+                'actor', 'associatedtype', 'class', 'deinit', 'enum', 'extension', 'func', 'import',
+                'init', 'internal', 'let', 'operator', 'private', 'protocol', 'public',
+                'static', 'struct', 'subscript', 'typealias', 'var'), suffix=r'\b'),
+             Keyword.Declaration)
+        ],
+        'comment': [
+            (r':param: [a-zA-Z_]\w*|:returns?:|(FIXME|MARK|TODO):',
+             Comment.Special)
+        ],
+
+        # Nested
+        'comment-single': [
+            (r'\n', Whitespace, '#pop'),
+            include('comment'),
+            (r'[^\n]+', Comment.Single)
+        ],
+        'comment-multi': [
+            include('comment'),
+            (r'[^*/]+', Comment.Multiline),
+            (r'/\*', Comment.Multiline, '#push'),
+            (r'\*/', Comment.Multiline, '#pop'),
+            (r'[*/]+', Comment.Multiline)
+        ],
+        'module': [
+            (r'\n', Whitespace, '#pop'),
+            (r'[a-zA-Z_]\w*', Name.Class),
+            include('root')
+        ],
+        'preproc': [
+            (r'\n', Whitespace, '#pop'),
+            include('keywords'),
+            (r'[A-Za-z]\w*', Comment.Preproc),
+            include('root')
+        ],
+        'string': [
+            (r'"', String, '#pop'),
+            include("string-common"),
+        ],
+        'string-multi': [
+            (r'"""', String, '#pop'),
+            include("string-common"),
+        ],
+        'string-common': [
+            (r'\\\(', String.Interpol, 'string-intp'),
+            (r"""\\['"\\nrt]|\\x[0-9a-fA-F]{2}|\\[0-7]{1,3}"""
+             r"""|\\u[0-9a-fA-F]{4}|\\U[0-9a-fA-F]{8}""", String.Escape),
+            (r'[^\\"]+', String),
+            (r'\\', String)
+        ],
+        'string-intp': [
+            (r'\(', String.Interpol, '#push'),
+            (r'\)', String.Interpol, '#pop'),
+            include('root')
+        ]
+    }
+
+    def get_tokens_unprocessed(self, text):
+        from pygments.lexers._cocoa_builtins import COCOA_INTERFACES, \
+            COCOA_PROTOCOLS, COCOA_PRIMITIVES
+
+        for index, token, value in \
+                RegexLexer.get_tokens_unprocessed(self, text):
+            if token is Name or token is Name.Class:
+                if value in COCOA_INTERFACES or value in COCOA_PROTOCOLS \
+                   or value in COCOA_PRIMITIVES:
+                    token = Name.Builtin.Pseudo
+
+            yield index, token, value
diff --git a/pygments/lexers/ooc.py b/pygments/lexers/ooc.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a990801a3c31535c8713c80a0b4d95215f4a6c8
--- /dev/null
+++ b/pygments/lexers/ooc.py
@@ -0,0 +1,84 @@
+"""
+    pygments.lexers.ooc
+    ~~~~~~~~~~~~~~~~~~~
+
+    Lexers for the Ooc language.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, bygroups, words
+from pygments.token import Text, Comment, Operator, Keyword, Name, String, \
+    Number, Punctuation
+
+__all__ = ['OocLexer']
+
+
+class OocLexer(RegexLexer):
+    """
+    For Ooc source code
+    """
+    name = 'Ooc'
+    url = 'https://ooc-lang.github.io/'
+    aliases = ['ooc']
+    filenames = ['*.ooc']
+    mimetypes = ['text/x-ooc']
+    version_added = '1.2'
+
+    tokens = {
+        'root': [
+            (words((
+                'class', 'interface', 'implement', 'abstract', 'extends', 'from',
+                'this', 'super', 'new', 'const', 'final', 'static', 'import',
+                'use', 'extern', 'inline', 'proto', 'break', 'continue',
+                'fallthrough', 'operator', 'if', 'else', 'for', 'while', 'do',
+                'switch', 'case', 'as', 'in', 'version', 'return', 'true',
+                'false', 'null'), prefix=r'\b', suffix=r'\b'),
+             Keyword),
+            (r'include\b', Keyword, 'include'),
+            (r'(cover)([ \t]+)(from)([ \t]+)(\w+[*@]?)',
+             bygroups(Keyword, Text, Keyword, Text, Name.Class)),
+            (r'(func)((?:[ \t]|\\\n)+)(~[a-z_]\w*)',
+             bygroups(Keyword, Text, Name.Function)),
+            (r'\bfunc\b', Keyword),
+            # Note: %= not listed on https://ooc-lang.github.io/docs/lang/operators/
+            (r'//.*', Comment),
+            (r'(?s)/\*.*?\*/', Comment.Multiline),
+            (r'(==?|\+=?|-[=>]?|\*=?|/=?|:=|!=?|%=?|\?|>{1,3}=?|<{1,3}=?|\.\.|'
+             r'&&?|\|\|?|\^=?)', Operator),
+            (r'(\.)([ \t]*)([a-z]\w*)', bygroups(Operator, Text,
+                                                 Name.Function)),
+            (r'[A-Z][A-Z0-9_]+', Name.Constant),
+            (r'[A-Z]\w*([@*]|\[[ \t]*\])?', Name.Class),
+
+            (r'([a-z]\w*(?:~[a-z]\w*)?)((?:[ \t]|\\\n)*)(?=\()',
+             bygroups(Name.Function, Text)),
+            (r'[a-z]\w*', Name.Variable),
+
+            # : introduces types
+            (r'[:(){}\[\];,]', Punctuation),
+
+            (r'0x[0-9a-fA-F]+', Number.Hex),
+            (r'0c[0-9]+', Number.Oct),
+            (r'0b[01]+', Number.Bin),
+            (r'[0-9_]\.[0-9_]*(?!\.)', Number.Float),
+            (r'[0-9_]+', Number.Decimal),
+
+            (r'"(?:\\.|\\[0-7]{1,3}|\\x[a-fA-F0-9]{1,2}|[^\\"])*"',
+             String.Double),
+            (r"'(?:\\.|\\[0-9]{1,3}|\\x[a-fA-F0-9]{1,2}|[^\\\'\n])'",
+             String.Char),
+            (r'@', Punctuation),  # pointer dereference
+            (r'\.', Punctuation),  # imports or chain operator
+
+            (r'\\[ \t\n]', Text),
+            (r'[ \t]+', Text),
+        ],
+        'include': [
+            (r'[\w/]+', Name),
+            (r',', Punctuation),
+            (r'[ \t]', Text),
+            (r'[;\n]', Text, '#pop'),
+        ],
+    }
diff --git a/pygments/lexers/openscad.py b/pygments/lexers/openscad.py
new file mode 100644
index 0000000000000000000000000000000000000000..b06de227c7c485db854e5ba25616431a84129ba8
--- /dev/null
+++ b/pygments/lexers/openscad.py
@@ -0,0 +1,96 @@
+"""
+    pygments.lexers.openscad
+    ~~~~~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for the OpenSCAD languages.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, bygroups, words, include
+from pygments.token import Text, Comment, Punctuation, Operator, Keyword, Name, Number, Whitespace, Literal, String
+
+__all__ = ['OpenScadLexer']
+
+
+class OpenScadLexer(RegexLexer):
+    """For openSCAD code.
+    """
+    name = "OpenSCAD"
+    url = "https://openscad.org/"
+    aliases = ["openscad"]
+    filenames = ["*.scad"]
+    mimetypes = ["application/x-openscad"]
+    version_added = '2.16'
+
+    tokens = {
+        "root": [
+            (r"[^\S\n]+", Whitespace),
+            (r'//', Comment.Single, 'comment-single'),
+            (r'/\*', Comment.Multiline, 'comment-multi'),
+            (r"[{}\[\]\(\),;:]", Punctuation),
+            (r"[*!#%\-+=?/]", Operator),
+            (r"<=|<|==|!=|>=|>|&&|\|\|", Operator),
+            (r"\$(f[asn]|t|vp[rtd]|children)", Operator),
+            (r"(undef|PI)\b", Keyword.Constant),
+            (
+                r"(use|include)((?:\s|\\\\s)+)",
+                bygroups(Keyword.Namespace, Text),
+                "includes",
+            ),
+            (r"(module)(\s*)([^\s\(]+)",
+             bygroups(Keyword.Namespace, Whitespace, Name.Namespace)),
+            (r"(function)(\s*)([^\s\(]+)",
+             bygroups(Keyword.Declaration, Whitespace, Name.Function)),
+            (words(("true", "false"), prefix=r"\b", suffix=r"\b"), Literal),
+            (words((
+                "function", "module", "include", "use", "for",
+                "intersection_for", "if", "else", "return"
+                ), prefix=r"\b", suffix=r"\b"), Keyword
+            ),
+            (words((
+                "circle", "square", "polygon", "text", "sphere", "cube",
+                "cylinder", "polyhedron", "translate", "rotate", "scale",
+                "resize", "mirror", "multmatrix", "color", "offset", "hull",
+                "minkowski", "union", "difference", "intersection", "abs",
+                "sign", "sin", "cos", "tan", "acos", "asin", "atan", "atan2",
+                "floor", "round", "ceil", "ln", "log", "pow", "sqrt", "exp",
+                "rands", "min", "max", "concat", "lookup", "str", "chr",
+                "search", "version", "version_num", "norm", "cross",
+                "parent_module", "echo", "import", "import_dxf",
+                "dxf_linear_extrude", "linear_extrude", "rotate_extrude",
+                "surface", "projection", "render", "dxf_cross",
+                "dxf_dim", "let", "assign", "len"
+                ), prefix=r"\b", suffix=r"\b"),
+                Name.Builtin
+            ),
+            (r"\bchildren\b", Name.Builtin.Pseudo),
+            (r'""".*?"""', String.Double),
+            (r'"(\\\\|\\[^\\]|[^"\\])*"', String.Double),
+            (r"-?\d+(\.\d+)?(e[+-]?\d+)?", Number),
+            (r"\w+", Name),
+        ],
+        "includes": [
+            (
+                r"(<)([^>]*)(>)",
+                bygroups(Punctuation, Comment.PreprocFile, Punctuation),
+            ),
+        ],
+        'comment': [
+            (r':param: [a-zA-Z_]\w*|:returns?:|(FIXME|MARK|TODO):',
+             Comment.Special)
+        ],
+        'comment-single': [
+            (r'\n', Text, '#pop'),
+            include('comment'),
+            (r'[^\n]+', Comment.Single)
+        ],
+        'comment-multi': [
+            include('comment'),
+            (r'[^*/]+', Comment.Multiline),
+            (r'/\*', Comment.Multiline, '#push'),
+            (r'\*/', Comment.Multiline, '#pop'),
+            (r'[*/]', Comment.Multiline)
+        ],
+    }
diff --git a/pygments/lexers/other.py b/pygments/lexers/other.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b7dfb4ab534f51b785f35a1c13478397debac6c
--- /dev/null
+++ b/pygments/lexers/other.py
@@ -0,0 +1,41 @@
+"""
+    pygments.lexers.other
+    ~~~~~~~~~~~~~~~~~~~~~
+
+    Just export lexer classes previously contained in this module.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+# ruff: noqa: F401
+from pygments.lexers.sql import SqlLexer, MySqlLexer, SqliteConsoleLexer
+from pygments.lexers.shell import BashLexer, BashSessionLexer, BatchLexer, \
+    TcshLexer
+from pygments.lexers.robotframework import RobotFrameworkLexer
+from pygments.lexers.testing import GherkinLexer
+from pygments.lexers.esoteric import BrainfuckLexer, BefungeLexer, RedcodeLexer
+from pygments.lexers.prolog import LogtalkLexer
+from pygments.lexers.snobol import SnobolLexer
+from pygments.lexers.rebol import RebolLexer
+from pygments.lexers.configs import KconfigLexer, Cfengine3Lexer
+from pygments.lexers.modeling import ModelicaLexer
+from pygments.lexers.scripting import AppleScriptLexer, MOOCodeLexer, \
+    HybrisLexer
+from pygments.lexers.graphics import PostScriptLexer, GnuplotLexer, \
+    AsymptoteLexer, PovrayLexer
+from pygments.lexers.business import ABAPLexer, OpenEdgeLexer, \
+    GoodDataCLLexer, MaqlLexer
+from pygments.lexers.automation import AutoItLexer, AutohotkeyLexer
+from pygments.lexers.dsls import ProtoBufLexer, BroLexer, PuppetLexer, \
+    MscgenLexer, VGLLexer
+from pygments.lexers.basic import CbmBasicV2Lexer
+from pygments.lexers.pawn import SourcePawnLexer, PawnLexer
+from pygments.lexers.ecl import ECLLexer
+from pygments.lexers.urbi import UrbiscriptLexer
+from pygments.lexers.smalltalk import SmalltalkLexer, NewspeakLexer
+from pygments.lexers.installers import NSISLexer, RPMSpecLexer
+from pygments.lexers.textedit import AwkLexer
+from pygments.lexers.smv import NuSMVLexer
+
+__all__ = []
diff --git a/pygments/lexers/parasail.py b/pygments/lexers/parasail.py
new file mode 100644
index 0000000000000000000000000000000000000000..150d6a9c48213856c423f4ceb28fb1ab7400f137
--- /dev/null
+++ b/pygments/lexers/parasail.py
@@ -0,0 +1,78 @@
+"""
+    pygments.lexers.parasail
+    ~~~~~~~~~~~~~~~~~~~~~~~~
+
+    Lexer for ParaSail.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import re
+
+from pygments.lexer import RegexLexer, include
+from pygments.token import Text, Comment, Operator, Keyword, Name, String, \
+    Number, Punctuation, Literal
+
+__all__ = ['ParaSailLexer']
+
+
+class ParaSailLexer(RegexLexer):
+    """
+    For ParaSail source code.
+    """
+
+    name = 'ParaSail'
+    url = 'http://www.parasail-lang.org'
+    aliases = ['parasail']
+    filenames = ['*.psi', '*.psl']
+    mimetypes = ['text/x-parasail']
+    version_added = '2.1'
+
+    flags = re.MULTILINE
+
+    tokens = {
+        'root': [
+            (r'[^\S\n]+', Text),
+            (r'//.*?\n', Comment.Single),
+            (r'\b(and|or|xor)=', Operator.Word),
+            (r'\b(and(\s+then)?|or(\s+else)?|xor|rem|mod|'
+             r'(is|not)\s+null)\b',
+             Operator.Word),
+            # Keywords
+            (r'\b(abs|abstract|all|block|class|concurrent|const|continue|'
+             r'each|end|exit|extends|exports|forward|func|global|implements|'
+             r'import|in|interface|is|lambda|locked|new|not|null|of|op|'
+             r'optional|private|queued|ref|return|reverse|separate|some|'
+             r'type|until|var|with|'
+             # Control flow
+             r'if|then|else|elsif|case|for|while|loop)\b',
+             Keyword.Reserved),
+            (r'(abstract\s+)?(interface|class|op|func|type)',
+             Keyword.Declaration),
+            # Literals
+            (r'"[^"]*"', String),
+            (r'\\[\'ntrf"0]', String.Escape),
+            (r'#[a-zA-Z]\w*', Literal),       # Enumeration
+            include('numbers'),
+            (r"'[^']'", String.Char),
+            (r'[a-zA-Z]\w*', Name),
+            # Operators and Punctuation
+            (r'(<==|==>|<=>|\*\*=|<\|=|<<=|>>=|==|!=|=\?|<=|>=|'
+             r'\*\*|<<|>>|=>|:=|\+=|-=|\*=|\|=|\||/=|\+|-|\*|/|'
+             r'\.\.|<\.\.|\.\.<|<\.\.<)',
+             Operator),
+            (r'(<|>|\[|\]|\(|\)|\||:|;|,|.|\{|\}|->)',
+             Punctuation),
+            (r'\n+', Text),
+        ],
+        'numbers': [
+            (r'\d[0-9_]*#[0-9a-fA-F][0-9a-fA-F_]*#', Number.Hex),  # any base
+            (r'0[xX][0-9a-fA-F][0-9a-fA-F_]*', Number.Hex),        # C-like hex
+            (r'0[bB][01][01_]*', Number.Bin),                      # C-like bin
+            (r'\d[0-9_]*\.\d[0-9_]*[eE][+-]\d[0-9_]*',             # float exp
+             Number.Float),
+            (r'\d[0-9_]*\.\d[0-9_]*', Number.Float),               # float
+            (r'\d[0-9_]*', Number.Integer),                        # integer
+        ],
+    }
diff --git a/pygments/lexers/parsers.py b/pygments/lexers/parsers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a4ed9d1df789fd8e559e3c6de2fb161c4620547
--- /dev/null
+++ b/pygments/lexers/parsers.py
@@ -0,0 +1,798 @@
+"""
+    pygments.lexers.parsers
+    ~~~~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for parser generators.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import re
+
+from pygments.lexer import RegexLexer, DelegatingLexer, \
+    include, bygroups, using
+from pygments.token import Punctuation, Other, Text, Comment, Operator, \
+    Keyword, Name, String, Number, Whitespace
+from pygments.lexers.jvm import JavaLexer
+from pygments.lexers.c_cpp import CLexer, CppLexer
+from pygments.lexers.objective import ObjectiveCLexer
+from pygments.lexers.d import DLexer
+from pygments.lexers.dotnet import CSharpLexer
+from pygments.lexers.ruby import RubyLexer
+from pygments.lexers.python import PythonLexer
+from pygments.lexers.perl import PerlLexer
+
+__all__ = ['RagelLexer', 'RagelEmbeddedLexer', 'RagelCLexer', 'RagelDLexer',
+           'RagelCppLexer', 'RagelObjectiveCLexer', 'RagelRubyLexer',
+           'RagelJavaLexer', 'AntlrLexer', 'AntlrPythonLexer',
+           'AntlrPerlLexer', 'AntlrRubyLexer', 'AntlrCppLexer',
+           'AntlrCSharpLexer', 'AntlrObjectiveCLexer',
+           'AntlrJavaLexer', 'AntlrActionScriptLexer',
+           'TreetopLexer', 'EbnfLexer']
+
+
+class RagelLexer(RegexLexer):
+    """A pure `Ragel `_ lexer.  Use this
+    for fragments of Ragel.  For ``.rl`` files, use
+    :class:`RagelEmbeddedLexer` instead (or one of the
+    language-specific subclasses).
+
+    """
+
+    name = 'Ragel'
+    url = 'http://www.colm.net/open-source/ragel/'
+    aliases = ['ragel']
+    filenames = []
+    version_added = '1.1'
+
+    tokens = {
+        'whitespace': [
+            (r'\s+', Whitespace)
+        ],
+        'comments': [
+            (r'\#.*$', Comment),
+        ],
+        'keywords': [
+            (r'(access|action|alphtype)\b', Keyword),
+            (r'(getkey|write|machine|include)\b', Keyword),
+            (r'(any|ascii|extend|alpha|digit|alnum|lower|upper)\b', Keyword),
+            (r'(xdigit|cntrl|graph|print|punct|space|zlen|empty)\b', Keyword)
+        ],
+        'numbers': [
+            (r'0x[0-9A-Fa-f]+', Number.Hex),
+            (r'[+-]?[0-9]+', Number.Integer),
+        ],
+        'literals': [
+            (r'"(\\\\|\\[^\\]|[^"\\])*"', String.Double),
+            (r"'(\\\\|\\[^\\]|[^'\\])*'", String.Single),
+            (r'\[(\\\\|\\[^\\]|[^\\\]])*\]', String),          # square bracket literals
+            (r'/(?!\*)(\\\\|\\[^\\]|[^/\\])*/', String.Regex),  # regular expressions
+        ],
+        'identifiers': [
+            (r'[a-zA-Z_]\w*', Name.Variable),
+        ],
+        'operators': [
+            (r',', Operator),                           # Join
+            (r'\||&|--?', Operator),                    # Union, Intersection and Subtraction
+            (r'\.|<:|:>>?', Operator),                  # Concatention
+            (r':', Operator),                           # Label
+            (r'->', Operator),                          # Epsilon Transition
+            (r'(>|\$|%|<|@|<>)(/|eof\b)', Operator),    # EOF Actions
+            (r'(>|\$|%|<|@|<>)(!|err\b)', Operator),    # Global Error Actions
+            (r'(>|\$|%|<|@|<>)(\^|lerr\b)', Operator),  # Local Error Actions
+            (r'(>|\$|%|<|@|<>)(~|to\b)', Operator),     # To-State Actions
+            (r'(>|\$|%|<|@|<>)(\*|from\b)', Operator),  # From-State Actions
+            (r'>|@|\$|%', Operator),                    # Transition Actions and Priorities
+            (r'\*|\?|\+|\{[0-9]*,[0-9]*\}', Operator),  # Repetition
+            (r'!|\^', Operator),                        # Negation
+            (r'\(|\)', Operator),                       # Grouping
+        ],
+        'root': [
+            include('literals'),
+            include('whitespace'),
+            include('comments'),
+            include('keywords'),
+            include('numbers'),
+            include('identifiers'),
+            include('operators'),
+            (r'\{', Punctuation, 'host'),
+            (r'=', Operator),
+            (r';', Punctuation),
+        ],
+        'host': [
+            (r'(' + r'|'.join((  # keep host code in largest possible chunks
+                r'[^{}\'"/#]+',  # exclude unsafe characters
+                r'[^\\]\\[{}]',  # allow escaped { or }
+
+                # strings and comments may safely contain unsafe characters
+                r'"(\\\\|\\[^\\]|[^"\\])*"',
+                r"'(\\\\|\\[^\\]|[^'\\])*'",
+                r'//.*$\n?',            # single line comment
+                r'/\*(.|\n)*?\*/',      # multi-line javadoc-style comment
+                r'\#.*$\n?',            # ruby comment
+
+                # regular expression: There's no reason for it to start
+                # with a * and this stops confusion with comments.
+                r'/(?!\*)(\\\\|\\[^\\]|[^/\\])*/',
+
+                # / is safe now that we've handled regex and javadoc comments
+                r'/',
+            )) + r')+', Other),
+
+            (r'\{', Punctuation, '#push'),
+            (r'\}', Punctuation, '#pop'),
+        ],
+    }
+
+
+class RagelEmbeddedLexer(RegexLexer):
+    """
+    A lexer for Ragel embedded in a host language file.
+
+    This will only highlight Ragel statements. If you want host language
+    highlighting then call the language-specific Ragel lexer.
+    """
+
+    name = 'Embedded Ragel'
+    aliases = ['ragel-em']
+    filenames = ['*.rl']
+    url = 'http://www.colm.net/open-source/ragel/'
+    version_added = '1.1'
+
+    tokens = {
+        'root': [
+            (r'(' + r'|'.join((   # keep host code in largest possible chunks
+                r'[^%\'"/#]+',    # exclude unsafe characters
+                r'%(?=[^%]|$)',   # a single % sign is okay, just not 2 of them
+
+                # strings and comments may safely contain unsafe characters
+                r'"(\\\\|\\[^\\]|[^"\\])*"',
+                r"'(\\\\|\\[^\\]|[^'\\])*'",
+                r'/\*(.|\n)*?\*/',      # multi-line javadoc-style comment
+                r'//.*$\n?',  # single line comment
+                r'\#.*$\n?',  # ruby/ragel comment
+                r'/(?!\*)(\\\\|\\[^\\]|[^/\\])*/',  # regular expression
+
+                # / is safe now that we've handled regex and javadoc comments
+                r'/',
+            )) + r')+', Other),
+
+            # Single Line FSM.
+            # Please don't put a quoted newline in a single line FSM.
+            # That's just mean. It will break this.
+            (r'(%%)(?![{%])(.*)($|;)(\n?)', bygroups(Punctuation,
+                                                     using(RagelLexer),
+                                                     Punctuation, Text)),
+
+            # Multi Line FSM.
+            (r'(%%%%|%%)\{', Punctuation, 'multi-line-fsm'),
+        ],
+        'multi-line-fsm': [
+            (r'(' + r'|'.join((  # keep ragel code in largest possible chunks.
+                r'(' + r'|'.join((
+                    r'[^}\'"\[/#]',   # exclude unsafe characters
+                    r'\}(?=[^%]|$)',   # } is okay as long as it's not followed by %
+                    r'\}%(?=[^%]|$)',  # ...well, one %'s okay, just not two...
+                    r'[^\\]\\[{}]',   # ...and } is okay if it's escaped
+
+                    # allow / if it's preceded with one of these symbols
+                    # (ragel EOF actions)
+                    r'(>|\$|%|<|@|<>)/',
+
+                    # specifically allow regex followed immediately by *
+                    # so it doesn't get mistaken for a comment
+                    r'/(?!\*)(\\\\|\\[^\\]|[^/\\])*/\*',
+
+                    # allow / as long as it's not followed by another / or by a *
+                    r'/(?=[^/*]|$)',
+
+                    # We want to match as many of these as we can in one block.
+                    # Not sure if we need the + sign here,
+                    # does it help performance?
+                )) + r')+',
+
+                # strings and comments may safely contain unsafe characters
+                r'"(\\\\|\\[^\\]|[^"\\])*"',
+                r"'(\\\\|\\[^\\]|[^'\\])*'",
+                r"\[(\\\\|\\[^\\]|[^\]\\])*\]",  # square bracket literal
+                r'/\*(.|\n)*?\*/',          # multi-line javadoc-style comment
+                r'//.*$\n?',                # single line comment
+                r'\#.*$\n?',                # ruby/ragel comment
+            )) + r')+', using(RagelLexer)),
+
+            (r'\}%%', Punctuation, '#pop'),
+        ]
+    }
+
+    def analyse_text(text):
+        return '@LANG: indep' in text
+
+
+class RagelRubyLexer(DelegatingLexer):
+    """
+    A lexer for Ragel in a Ruby host file.
+    """
+
+    name = 'Ragel in Ruby Host'
+    aliases = ['ragel-ruby', 'ragel-rb']
+    filenames = ['*.rl']
+    url = 'http://www.colm.net/open-source/ragel/'
+    version_added = '1.1'
+
+    def __init__(self, **options):
+        super().__init__(RubyLexer, RagelEmbeddedLexer, **options)
+
+    def analyse_text(text):
+        return '@LANG: ruby' in text
+
+
+class RagelCLexer(DelegatingLexer):
+    """
+    A lexer for Ragel in a C host file.
+    """
+
+    name = 'Ragel in C Host'
+    aliases = ['ragel-c']
+    filenames = ['*.rl']
+    url = 'http://www.colm.net/open-source/ragel/'
+    version_added = '1.1'
+
+    def __init__(self, **options):
+        super().__init__(CLexer, RagelEmbeddedLexer, **options)
+
+    def analyse_text(text):
+        return '@LANG: c' in text
+
+
+class RagelDLexer(DelegatingLexer):
+    """
+    A lexer for Ragel in a D host file.
+    """
+
+    name = 'Ragel in D Host'
+    aliases = ['ragel-d']
+    filenames = ['*.rl']
+    url = 'http://www.colm.net/open-source/ragel/'
+    version_added = '1.1'
+
+    def __init__(self, **options):
+        super().__init__(DLexer, RagelEmbeddedLexer, **options)
+
+    def analyse_text(text):
+        return '@LANG: d' in text
+
+
+class RagelCppLexer(DelegatingLexer):
+    """
+    A lexer for Ragel in a C++ host file.
+    """
+
+    name = 'Ragel in CPP Host'
+    aliases = ['ragel-cpp']
+    filenames = ['*.rl']
+    url = 'http://www.colm.net/open-source/ragel/'
+    version_added = '1.1'
+
+    def __init__(self, **options):
+        super().__init__(CppLexer, RagelEmbeddedLexer, **options)
+
+    def analyse_text(text):
+        return '@LANG: c++' in text
+
+
+class RagelObjectiveCLexer(DelegatingLexer):
+    """
+    A lexer for Ragel in an Objective C host file.
+    """
+
+    name = 'Ragel in Objective C Host'
+    aliases = ['ragel-objc']
+    filenames = ['*.rl']
+    url = 'http://www.colm.net/open-source/ragel/'
+    version_added = '1.1'
+
+    def __init__(self, **options):
+        super().__init__(ObjectiveCLexer, RagelEmbeddedLexer, **options)
+
+    def analyse_text(text):
+        return '@LANG: objc' in text
+
+
+class RagelJavaLexer(DelegatingLexer):
+    """
+    A lexer for Ragel in a Java host file.
+    """
+
+    name = 'Ragel in Java Host'
+    aliases = ['ragel-java']
+    filenames = ['*.rl']
+    url = 'http://www.colm.net/open-source/ragel/'
+    version_added = '1.1'
+
+    def __init__(self, **options):
+        super().__init__(JavaLexer, RagelEmbeddedLexer, **options)
+
+    def analyse_text(text):
+        return '@LANG: java' in text
+
+
+class AntlrLexer(RegexLexer):
+    """
+    Generic ANTLR Lexer.
+    Should not be called directly, instead
+    use DelegatingLexer for your target language.
+    """
+
+    name = 'ANTLR'
+    aliases = ['antlr']
+    filenames = []
+    url = 'https://www.antlr.org'
+    version_added = '1.1'
+
+    _id = r'[A-Za-z]\w*'
+    _TOKEN_REF = r'[A-Z]\w*'
+    _RULE_REF = r'[a-z]\w*'
+    _STRING_LITERAL = r'\'(?:\\\\|\\\'|[^\']*)\''
+    _INT = r'[0-9]+'
+
+    tokens = {
+        'whitespace': [
+            (r'\s+', Whitespace),
+        ],
+        'comments': [
+            (r'//.*$', Comment),
+            (r'/\*(.|\n)*?\*/', Comment),
+        ],
+        'root': [
+            include('whitespace'),
+            include('comments'),
+
+            (r'(lexer|parser|tree)?(\s*)(grammar\b)(\s*)(' + _id + ')(;)',
+             bygroups(Keyword, Whitespace, Keyword, Whitespace, Name.Class,
+                      Punctuation)),
+            # optionsSpec
+            (r'options\b', Keyword, 'options'),
+            # tokensSpec
+            (r'tokens\b', Keyword, 'tokens'),
+            # attrScope
+            (r'(scope)(\s*)(' + _id + r')(\s*)(\{)',
+             bygroups(Keyword, Whitespace, Name.Variable, Whitespace,
+                      Punctuation), 'action'),
+            # exception
+            (r'(catch|finally)\b', Keyword, 'exception'),
+            # action
+            (r'(@' + _id + r')(\s*)(::)?(\s*)(' + _id + r')(\s*)(\{)',
+             bygroups(Name.Label, Whitespace, Punctuation, Whitespace,
+                      Name.Label, Whitespace, Punctuation), 'action'),
+            # rule
+            (r'((?:protected|private|public|fragment)\b)?(\s*)(' + _id + ')(!)?',
+             bygroups(Keyword, Whitespace, Name.Label, Punctuation),
+             ('rule-alts', 'rule-prelims')),
+        ],
+        'exception': [
+            (r'\n', Whitespace, '#pop'),
+            (r'\s', Whitespace),
+            include('comments'),
+
+            (r'\[', Punctuation, 'nested-arg-action'),
+            (r'\{', Punctuation, 'action'),
+        ],
+        'rule-prelims': [
+            include('whitespace'),
+            include('comments'),
+
+            (r'returns\b', Keyword),
+            (r'\[', Punctuation, 'nested-arg-action'),
+            (r'\{', Punctuation, 'action'),
+            # throwsSpec
+            (r'(throws)(\s+)(' + _id + ')',
+             bygroups(Keyword, Whitespace, Name.Label)),
+            (r'(,)(\s*)(' + _id + ')',
+             bygroups(Punctuation, Whitespace, Name.Label)),  # Additional throws
+            # optionsSpec
+            (r'options\b', Keyword, 'options'),
+            # ruleScopeSpec - scope followed by target language code or name of action
+            # TODO finish implementing other possibilities for scope
+            # L173 ANTLRv3.g from ANTLR book
+            (r'(scope)(\s+)(\{)', bygroups(Keyword, Whitespace, Punctuation),
+             'action'),
+            (r'(scope)(\s+)(' + _id + r')(\s*)(;)',
+             bygroups(Keyword, Whitespace, Name.Label, Whitespace, Punctuation)),
+            # ruleAction
+            (r'(@' + _id + r')(\s*)(\{)',
+             bygroups(Name.Label, Whitespace, Punctuation), 'action'),
+            # finished prelims, go to rule alts!
+            (r':', Punctuation, '#pop')
+        ],
+        'rule-alts': [
+            include('whitespace'),
+            include('comments'),
+
+            # These might need to go in a separate 'block' state triggered by (
+            (r'options\b', Keyword, 'options'),
+            (r':', Punctuation),
+
+            # literals
+            (r'"(\\\\|\\[^\\]|[^"\\])*"', String.Double),
+            (r"'(\\\\|\\[^\\]|[^'\\])*'", String.Single),
+            (r'<<([^>]|>[^>])>>', String),
+            # identifiers
+            # Tokens start with capital letter.
+            (r'\$?[A-Z_]\w*', Name.Constant),
+            # Rules start with small letter.
+            (r'\$?[a-z_]\w*', Name.Variable),
+            # operators
+            (r'(\+|\||->|=>|=|\(|\)|\.\.|\.|\?|\*|\^|!|\#|~)', Operator),
+            (r',', Punctuation),
+            (r'\[', Punctuation, 'nested-arg-action'),
+            (r'\{', Punctuation, 'action'),
+            (r';', Punctuation, '#pop')
+        ],
+        'tokens': [
+            include('whitespace'),
+            include('comments'),
+            (r'\{', Punctuation),
+            (r'(' + _TOKEN_REF + r')(\s*)(=)?(\s*)(' + _STRING_LITERAL
+             + r')?(\s*)(;)',
+             bygroups(Name.Label, Whitespace, Punctuation, Whitespace,
+                      String, Whitespace, Punctuation)),
+            (r'\}', Punctuation, '#pop'),
+        ],
+        'options': [
+            include('whitespace'),
+            include('comments'),
+            (r'\{', Punctuation),
+            (r'(' + _id + r')(\s*)(=)(\s*)(' +
+             '|'.join((_id, _STRING_LITERAL, _INT, r'\*')) + r')(\s*)(;)',
+             bygroups(Name.Variable, Whitespace, Punctuation, Whitespace,
+                      Text, Whitespace, Punctuation)),
+            (r'\}', Punctuation, '#pop'),
+        ],
+        'action': [
+            (r'(' + r'|'.join((    # keep host code in largest possible chunks
+                r'[^${}\'"/\\]+',  # exclude unsafe characters
+
+                # strings and comments may safely contain unsafe characters
+                r'"(\\\\|\\[^\\]|[^"\\])*"',
+                r"'(\\\\|\\[^\\]|[^'\\])*'",
+                r'//.*$\n?',            # single line comment
+                r'/\*(.|\n)*?\*/',      # multi-line javadoc-style comment
+
+                # regular expression: There's no reason for it to start
+                # with a * and this stops confusion with comments.
+                r'/(?!\*)(\\\\|\\[^\\]|[^/\\])*/',
+
+                # backslashes are okay, as long as we are not backslashing a %
+                r'\\(?!%)',
+
+                # Now that we've handled regex and javadoc comments
+                # it's safe to let / through.
+                r'/',
+            )) + r')+', Other),
+            (r'(\\)(%)', bygroups(Punctuation, Other)),
+            (r'(\$[a-zA-Z]+)(\.?)(text|value)?',
+             bygroups(Name.Variable, Punctuation, Name.Property)),
+            (r'\{', Punctuation, '#push'),
+            (r'\}', Punctuation, '#pop'),
+        ],
+        'nested-arg-action': [
+            (r'(' + r'|'.join((    # keep host code in largest possible chunks.
+                r'[^$\[\]\'"/]+',  # exclude unsafe characters
+
+                # strings and comments may safely contain unsafe characters
+                r'"(\\\\|\\[^\\]|[^"\\])*"',
+                r"'(\\\\|\\[^\\]|[^'\\])*'",
+                r'//.*$\n?',            # single line comment
+                r'/\*(.|\n)*?\*/',      # multi-line javadoc-style comment
+
+                # regular expression: There's no reason for it to start
+                # with a * and this stops confusion with comments.
+                r'/(?!\*)(\\\\|\\[^\\]|[^/\\])*/',
+
+                # Now that we've handled regex and javadoc comments
+                # it's safe to let / through.
+                r'/',
+            )) + r')+', Other),
+
+
+            (r'\[', Punctuation, '#push'),
+            (r'\]', Punctuation, '#pop'),
+            (r'(\$[a-zA-Z]+)(\.?)(text|value)?',
+             bygroups(Name.Variable, Punctuation, Name.Property)),
+            (r'(\\\\|\\\]|\\\[|[^\[\]])+', Other),
+        ]
+    }
+
+    def analyse_text(text):
+        return re.search(r'^\s*grammar\s+[a-zA-Z0-9]+\s*;', text, re.M)
+
+
+# http://www.antlr.org/wiki/display/ANTLR3/Code+Generation+Targets
+
+class AntlrCppLexer(DelegatingLexer):
+    """
+    ANTLR with C++ Target
+    """
+
+    name = 'ANTLR With CPP Target'
+    aliases = ['antlr-cpp']
+    filenames = ['*.G', '*.g']
+    url = 'https://www.antlr.org'
+    version_added = '1.1'
+
+    def __init__(self, **options):
+        super().__init__(CppLexer, AntlrLexer, **options)
+
+    def analyse_text(text):
+        return AntlrLexer.analyse_text(text) and \
+            re.search(r'^\s*language\s*=\s*C\s*;', text, re.M)
+
+
+class AntlrObjectiveCLexer(DelegatingLexer):
+    """
+    ANTLR with Objective-C Target
+    """
+
+    name = 'ANTLR With ObjectiveC Target'
+    aliases = ['antlr-objc']
+    filenames = ['*.G', '*.g']
+    url = 'https://www.antlr.org'
+    version_added = '1.1'
+
+    def __init__(self, **options):
+        super().__init__(ObjectiveCLexer, AntlrLexer, **options)
+
+    def analyse_text(text):
+        return AntlrLexer.analyse_text(text) and \
+            re.search(r'^\s*language\s*=\s*ObjC\s*;', text)
+
+
+class AntlrCSharpLexer(DelegatingLexer):
+    """
+    ANTLR with C# Target
+    """
+
+    name = 'ANTLR With C# Target'
+    aliases = ['antlr-csharp', 'antlr-c#']
+    filenames = ['*.G', '*.g']
+    url = 'https://www.antlr.org'
+    version_added = '1.1'
+
+    def __init__(self, **options):
+        super().__init__(CSharpLexer, AntlrLexer, **options)
+
+    def analyse_text(text):
+        return AntlrLexer.analyse_text(text) and \
+            re.search(r'^\s*language\s*=\s*CSharp2\s*;', text, re.M)
+
+
+class AntlrPythonLexer(DelegatingLexer):
+    """
+    ANTLR with Python Target
+    """
+
+    name = 'ANTLR With Python Target'
+    aliases = ['antlr-python']
+    filenames = ['*.G', '*.g']
+    url = 'https://www.antlr.org'
+    version_added = '1.1'
+
+    def __init__(self, **options):
+        super().__init__(PythonLexer, AntlrLexer, **options)
+
+    def analyse_text(text):
+        return AntlrLexer.analyse_text(text) and \
+            re.search(r'^\s*language\s*=\s*Python\s*;', text, re.M)
+
+
+class AntlrJavaLexer(DelegatingLexer):
+    """
+    ANTLR with Java Target
+    """
+
+    name = 'ANTLR With Java Target'
+    aliases = ['antlr-java']
+    filenames = ['*.G', '*.g']
+    url = 'https://www.antlr.org'
+    version_added = '1.1'
+
+    def __init__(self, **options):
+        super().__init__(JavaLexer, AntlrLexer, **options)
+
+    def analyse_text(text):
+        # Antlr language is Java by default
+        return AntlrLexer.analyse_text(text) and 0.9
+
+
+class AntlrRubyLexer(DelegatingLexer):
+    """
+    ANTLR with Ruby Target
+    """
+
+    name = 'ANTLR With Ruby Target'
+    aliases = ['antlr-ruby', 'antlr-rb']
+    filenames = ['*.G', '*.g']
+    url = 'https://www.antlr.org'
+    version_added = '1.1'
+
+    def __init__(self, **options):
+        super().__init__(RubyLexer, AntlrLexer, **options)
+
+    def analyse_text(text):
+        return AntlrLexer.analyse_text(text) and \
+            re.search(r'^\s*language\s*=\s*Ruby\s*;', text, re.M)
+
+
+class AntlrPerlLexer(DelegatingLexer):
+    """
+    ANTLR with Perl Target
+    """
+
+    name = 'ANTLR With Perl Target'
+    aliases = ['antlr-perl']
+    filenames = ['*.G', '*.g']
+    url = 'https://www.antlr.org'
+    version_added = '1.1'
+
+    def __init__(self, **options):
+        super().__init__(PerlLexer, AntlrLexer, **options)
+
+    def analyse_text(text):
+        return AntlrLexer.analyse_text(text) and \
+            re.search(r'^\s*language\s*=\s*Perl5\s*;', text, re.M)
+
+
+class AntlrActionScriptLexer(DelegatingLexer):
+    """
+    ANTLR with ActionScript Target
+    """
+
+    name = 'ANTLR With ActionScript Target'
+    aliases = ['antlr-actionscript', 'antlr-as']
+    filenames = ['*.G', '*.g']
+    url = 'https://www.antlr.org'
+    version_added = '1.1'
+
+    def __init__(self, **options):
+        from pygments.lexers.actionscript import ActionScriptLexer
+        super().__init__(ActionScriptLexer, AntlrLexer, **options)
+
+    def analyse_text(text):
+        return AntlrLexer.analyse_text(text) and \
+            re.search(r'^\s*language\s*=\s*ActionScript\s*;', text, re.M)
+
+
+class TreetopBaseLexer(RegexLexer):
+    """
+    A base lexer for `Treetop `_ grammars.
+    Not for direct use; use :class:`TreetopLexer` instead.
+
+    .. versionadded:: 1.6
+    """
+
+    tokens = {
+        'root': [
+            include('space'),
+            (r'require[ \t]+[^\n\r]+[\n\r]', Other),
+            (r'module\b', Keyword.Namespace, 'module'),
+            (r'grammar\b', Keyword, 'grammar'),
+        ],
+        'module': [
+            include('space'),
+            include('end'),
+            (r'module\b', Keyword, '#push'),
+            (r'grammar\b', Keyword, 'grammar'),
+            (r'[A-Z]\w*(?:::[A-Z]\w*)*', Name.Namespace),
+        ],
+        'grammar': [
+            include('space'),
+            include('end'),
+            (r'rule\b', Keyword, 'rule'),
+            (r'include\b', Keyword, 'include'),
+            (r'[A-Z]\w*', Name),
+        ],
+        'include': [
+            include('space'),
+            (r'[A-Z]\w*(?:::[A-Z]\w*)*', Name.Class, '#pop'),
+        ],
+        'rule': [
+            include('space'),
+            include('end'),
+            (r'"(\\\\|\\[^\\]|[^"\\])*"', String.Double),
+            (r"'(\\\\|\\[^\\]|[^'\\])*'", String.Single),
+            (r'([A-Za-z_]\w*)(:)', bygroups(Name.Label, Punctuation)),
+            (r'[A-Za-z_]\w*', Name),
+            (r'[()]', Punctuation),
+            (r'[?+*/&!~]', Operator),
+            (r'\[(?:\\.|\[:\^?[a-z]+:\]|[^\\\]])+\]', String.Regex),
+            (r'([0-9]*)(\.\.)([0-9]*)',
+             bygroups(Number.Integer, Operator, Number.Integer)),
+            (r'(<)([^>]+)(>)', bygroups(Punctuation, Name.Class, Punctuation)),
+            (r'\{', Punctuation, 'inline_module'),
+            (r'\.', String.Regex),
+        ],
+        'inline_module': [
+            (r'\{', Other, 'ruby'),
+            (r'\}', Punctuation, '#pop'),
+            (r'[^{}]+', Other),
+        ],
+        'ruby': [
+            (r'\{', Other, '#push'),
+            (r'\}', Other, '#pop'),
+            (r'[^{}]+', Other),
+        ],
+        'space': [
+            (r'[ \t\n\r]+', Whitespace),
+            (r'#[^\n]*', Comment.Single),
+        ],
+        'end': [
+            (r'end\b', Keyword, '#pop'),
+        ],
+    }
+
+
+class TreetopLexer(DelegatingLexer):
+    """
+    A lexer for Treetop grammars.
+    """
+
+    name = 'Treetop'
+    aliases = ['treetop']
+    filenames = ['*.treetop', '*.tt']
+    url = 'https://cjheath.github.io/treetop'
+    version_added = '1.6'
+
+    def __init__(self, **options):
+        super().__init__(RubyLexer, TreetopBaseLexer, **options)
+
+
+class EbnfLexer(RegexLexer):
+    """
+    Lexer for `ISO/IEC 14977 EBNF
+    `_
+    grammars.
+    """
+
+    name = 'EBNF'
+    aliases = ['ebnf']
+    filenames = ['*.ebnf']
+    mimetypes = ['text/x-ebnf']
+    url = 'https://en.wikipedia.org/wiki/Extended_Backus%E2%80%93Naur_Form'
+    version_added = '2.0'
+
+    tokens = {
+        'root': [
+            include('whitespace'),
+            include('comment_start'),
+            include('identifier'),
+            (r'=', Operator, 'production'),
+        ],
+        'production': [
+            include('whitespace'),
+            include('comment_start'),
+            include('identifier'),
+            (r'"[^"]*"', String.Double),
+            (r"'[^']*'", String.Single),
+            (r'(\?[^?]*\?)', Name.Entity),
+            (r'[\[\]{}(),|]', Punctuation),
+            (r'-', Operator),
+            (r';', Punctuation, '#pop'),
+            (r'\.', Punctuation, '#pop'),
+        ],
+        'whitespace': [
+            (r'\s+', Text),
+        ],
+        'comment_start': [
+            (r'\(\*', Comment.Multiline, 'comment'),
+        ],
+        'comment': [
+            (r'[^*)]', Comment.Multiline),
+            include('comment_start'),
+            (r'\*\)', Comment.Multiline, '#pop'),
+            (r'[*)]', Comment.Multiline),
+        ],
+        'identifier': [
+            (r'([a-zA-Z][\w \-]*)', Keyword),
+        ],
+    }
diff --git a/pygments/lexers/pascal.py b/pygments/lexers/pascal.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f40dcc86a1fc26e0277ffd311f6086d231fc33e
--- /dev/null
+++ b/pygments/lexers/pascal.py
@@ -0,0 +1,644 @@
+"""
+    pygments.lexers.pascal
+    ~~~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for Pascal family languages.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import re
+
+from pygments.lexer import Lexer
+from pygments.util import get_bool_opt, get_list_opt
+from pygments.token import Comment, Operator, Keyword, Name, String, \
+    Number, Punctuation, Error, Whitespace
+from pygments.scanner import Scanner
+
+# compatibility import
+from pygments.lexers.modula2 import Modula2Lexer # noqa: F401
+
+__all__ = ['DelphiLexer', 'PortugolLexer']
+
+
+class PortugolLexer(Lexer):
+    """For Portugol, a Pascal dialect with keywords in Portuguese."""
+    name = 'Portugol'
+    aliases = ['portugol']
+    filenames = ['*.alg', '*.portugol']
+    mimetypes = []
+    url = "https://www.apoioinformatica.inf.br/produtos/visualg/linguagem"
+    version_added = ''
+
+    def __init__(self, **options):
+        Lexer.__init__(self, **options)
+        self.lexer = DelphiLexer(**options, portugol=True)
+
+    def get_tokens_unprocessed(self, text):
+        return self.lexer.get_tokens_unprocessed(text)
+
+
+class DelphiLexer(Lexer):
+    """
+    For Delphi (Borland Object Pascal),
+    Turbo Pascal and Free Pascal source code.
+
+    Additional options accepted:
+
+    `turbopascal`
+        Highlight Turbo Pascal specific keywords (default: ``True``).
+    `delphi`
+        Highlight Borland Delphi specific keywords (default: ``True``).
+    `freepascal`
+        Highlight Free Pascal specific keywords (default: ``True``).
+    `units`
+        A list of units that should be considered builtin, supported are
+        ``System``, ``SysUtils``, ``Classes`` and ``Math``.
+        Default is to consider all of them builtin.
+    """
+    name = 'Delphi'
+    aliases = ['delphi', 'pas', 'pascal', 'objectpascal']
+    filenames = ['*.pas', '*.dpr']
+    mimetypes = ['text/x-pascal']
+    url = 'https://www.embarcadero.com/products/delphi'
+    version_added = ''
+
+    TURBO_PASCAL_KEYWORDS = (
+        'absolute', 'and', 'array', 'asm', 'begin', 'break', 'case',
+        'const', 'constructor', 'continue', 'destructor', 'div', 'do',
+        'downto', 'else', 'end', 'file', 'for', 'function', 'goto',
+        'if', 'implementation', 'in', 'inherited', 'inline', 'interface',
+        'label', 'mod', 'nil', 'not', 'object', 'of', 'on', 'operator',
+        'or', 'packed', 'procedure', 'program', 'record', 'reintroduce',
+        'repeat', 'self', 'set', 'shl', 'shr', 'string', 'then', 'to',
+        'type', 'unit', 'until', 'uses', 'var', 'while', 'with', 'xor'
+    )
+
+    DELPHI_KEYWORDS = (
+        'as', 'class', 'except', 'exports', 'finalization', 'finally',
+        'initialization', 'is', 'library', 'on', 'property', 'raise',
+        'threadvar', 'try'
+    )
+
+    FREE_PASCAL_KEYWORDS = (
+        'dispose', 'exit', 'false', 'new', 'true'
+    )
+
+    BLOCK_KEYWORDS = {
+        'begin', 'class', 'const', 'constructor', 'destructor', 'end',
+        'finalization', 'function', 'implementation', 'initialization',
+        'label', 'library', 'operator', 'procedure', 'program', 'property',
+        'record', 'threadvar', 'type', 'unit', 'uses', 'var'
+    }
+
+    FUNCTION_MODIFIERS = {
+        'alias', 'cdecl', 'export', 'inline', 'interrupt', 'nostackframe',
+        'pascal', 'register', 'safecall', 'softfloat', 'stdcall',
+        'varargs', 'name', 'dynamic', 'near', 'virtual', 'external',
+        'override', 'assembler'
+    }
+
+    # XXX: those aren't global. but currently we know no way for defining
+    #      them just for the type context.
+    DIRECTIVES = {
+        'absolute', 'abstract', 'assembler', 'cppdecl', 'default', 'far',
+        'far16', 'forward', 'index', 'oldfpccall', 'private', 'protected',
+        'published', 'public'
+    }
+
+    BUILTIN_TYPES = {
+        'ansichar', 'ansistring', 'bool', 'boolean', 'byte', 'bytebool',
+        'cardinal', 'char', 'comp', 'currency', 'double', 'dword',
+        'extended', 'int64', 'integer', 'iunknown', 'longbool', 'longint',
+        'longword', 'pansichar', 'pansistring', 'pbool', 'pboolean',
+        'pbyte', 'pbytearray', 'pcardinal', 'pchar', 'pcomp', 'pcurrency',
+        'pdate', 'pdatetime', 'pdouble', 'pdword', 'pextended', 'phandle',
+        'pint64', 'pinteger', 'plongint', 'plongword', 'pointer',
+        'ppointer', 'pshortint', 'pshortstring', 'psingle', 'psmallint',
+        'pstring', 'pvariant', 'pwidechar', 'pwidestring', 'pword',
+        'pwordarray', 'pwordbool', 'real', 'real48', 'shortint',
+        'shortstring', 'single', 'smallint', 'string', 'tclass', 'tdate',
+        'tdatetime', 'textfile', 'thandle', 'tobject', 'ttime', 'variant',
+        'widechar', 'widestring', 'word', 'wordbool'
+    }
+
+    BUILTIN_UNITS = {
+        'System': (
+            'abs', 'acquireexceptionobject', 'addr', 'ansitoutf8',
+            'append', 'arctan', 'assert', 'assigned', 'assignfile',
+            'beginthread', 'blockread', 'blockwrite', 'break', 'chdir',
+            'chr', 'close', 'closefile', 'comptocurrency', 'comptodouble',
+            'concat', 'continue', 'copy', 'cos', 'dec', 'delete',
+            'dispose', 'doubletocomp', 'endthread', 'enummodules',
+            'enumresourcemodules', 'eof', 'eoln', 'erase', 'exceptaddr',
+            'exceptobject', 'exclude', 'exit', 'exp', 'filepos', 'filesize',
+            'fillchar', 'finalize', 'findclasshinstance', 'findhinstance',
+            'findresourcehinstance', 'flush', 'frac', 'freemem',
+            'get8087cw', 'getdir', 'getlasterror', 'getmem',
+            'getmemorymanager', 'getmodulefilename', 'getvariantmanager',
+            'halt', 'hi', 'high', 'inc', 'include', 'initialize', 'insert',
+            'int', 'ioresult', 'ismemorymanagerset', 'isvariantmanagerset',
+            'length', 'ln', 'lo', 'low', 'mkdir', 'move', 'new', 'odd',
+            'olestrtostring', 'olestrtostrvar', 'ord', 'paramcount',
+            'paramstr', 'pi', 'pos', 'pred', 'ptr', 'pucs4chars', 'random',
+            'randomize', 'read', 'readln', 'reallocmem',
+            'releaseexceptionobject', 'rename', 'reset', 'rewrite', 'rmdir',
+            'round', 'runerror', 'seek', 'seekeof', 'seekeoln',
+            'set8087cw', 'setlength', 'setlinebreakstyle',
+            'setmemorymanager', 'setstring', 'settextbuf',
+            'setvariantmanager', 'sin', 'sizeof', 'slice', 'sqr', 'sqrt',
+            'str', 'stringofchar', 'stringtoolestr', 'stringtowidechar',
+            'succ', 'swap', 'trunc', 'truncate', 'typeinfo',
+            'ucs4stringtowidestring', 'unicodetoutf8', 'uniquestring',
+            'upcase', 'utf8decode', 'utf8encode', 'utf8toansi',
+            'utf8tounicode', 'val', 'vararrayredim', 'varclear',
+            'widecharlentostring', 'widecharlentostrvar',
+            'widechartostring', 'widechartostrvar',
+            'widestringtoucs4string', 'write', 'writeln'
+        ),
+        'SysUtils': (
+            'abort', 'addexitproc', 'addterminateproc', 'adjustlinebreaks',
+            'allocmem', 'ansicomparefilename', 'ansicomparestr',
+            'ansicomparetext', 'ansidequotedstr', 'ansiextractquotedstr',
+            'ansilastchar', 'ansilowercase', 'ansilowercasefilename',
+            'ansipos', 'ansiquotedstr', 'ansisamestr', 'ansisametext',
+            'ansistrcomp', 'ansistricomp', 'ansistrlastchar', 'ansistrlcomp',
+            'ansistrlicomp', 'ansistrlower', 'ansistrpos', 'ansistrrscan',
+            'ansistrscan', 'ansistrupper', 'ansiuppercase',
+            'ansiuppercasefilename', 'appendstr', 'assignstr', 'beep',
+            'booltostr', 'bytetocharindex', 'bytetocharlen', 'bytetype',
+            'callterminateprocs', 'changefileext', 'charlength',
+            'chartobyteindex', 'chartobytelen', 'comparemem', 'comparestr',
+            'comparetext', 'createdir', 'createguid', 'currentyear',
+            'currtostr', 'currtostrf', 'date', 'datetimetofiledate',
+            'datetimetostr', 'datetimetostring', 'datetimetosystemtime',
+            'datetimetotimestamp', 'datetostr', 'dayofweek', 'decodedate',
+            'decodedatefully', 'decodetime', 'deletefile', 'directoryexists',
+            'diskfree', 'disksize', 'disposestr', 'encodedate', 'encodetime',
+            'exceptionerrormessage', 'excludetrailingbackslash',
+            'excludetrailingpathdelimiter', 'expandfilename',
+            'expandfilenamecase', 'expanduncfilename', 'extractfiledir',
+            'extractfiledrive', 'extractfileext', 'extractfilename',
+            'extractfilepath', 'extractrelativepath', 'extractshortpathname',
+            'fileage', 'fileclose', 'filecreate', 'filedatetodatetime',
+            'fileexists', 'filegetattr', 'filegetdate', 'fileisreadonly',
+            'fileopen', 'fileread', 'filesearch', 'fileseek', 'filesetattr',
+            'filesetdate', 'filesetreadonly', 'filewrite', 'finalizepackage',
+            'findclose', 'findcmdlineswitch', 'findfirst', 'findnext',
+            'floattocurr', 'floattodatetime', 'floattodecimal', 'floattostr',
+            'floattostrf', 'floattotext', 'floattotextfmt', 'fmtloadstr',
+            'fmtstr', 'forcedirectories', 'format', 'formatbuf', 'formatcurr',
+            'formatdatetime', 'formatfloat', 'freeandnil', 'getcurrentdir',
+            'getenvironmentvariable', 'getfileversion', 'getformatsettings',
+            'getlocaleformatsettings', 'getmodulename', 'getpackagedescription',
+            'getpackageinfo', 'gettime', 'guidtostring', 'incamonth',
+            'includetrailingbackslash', 'includetrailingpathdelimiter',
+            'incmonth', 'initializepackage', 'interlockeddecrement',
+            'interlockedexchange', 'interlockedexchangeadd',
+            'interlockedincrement', 'inttohex', 'inttostr', 'isdelimiter',
+            'isequalguid', 'isleapyear', 'ispathdelimiter', 'isvalidident',
+            'languages', 'lastdelimiter', 'loadpackage', 'loadstr',
+            'lowercase', 'msecstotimestamp', 'newstr', 'nextcharindex', 'now',
+            'outofmemoryerror', 'quotedstr', 'raiselastoserror',
+            'raiselastwin32error', 'removedir', 'renamefile', 'replacedate',
+            'replacetime', 'safeloadlibrary', 'samefilename', 'sametext',
+            'setcurrentdir', 'showexception', 'sleep', 'stralloc', 'strbufsize',
+            'strbytetype', 'strcat', 'strcharlength', 'strcomp', 'strcopy',
+            'strdispose', 'strecopy', 'strend', 'strfmt', 'stricomp',
+            'stringreplace', 'stringtoguid', 'strlcat', 'strlcomp', 'strlcopy',
+            'strlen', 'strlfmt', 'strlicomp', 'strlower', 'strmove', 'strnew',
+            'strnextchar', 'strpas', 'strpcopy', 'strplcopy', 'strpos',
+            'strrscan', 'strscan', 'strtobool', 'strtobooldef', 'strtocurr',
+            'strtocurrdef', 'strtodate', 'strtodatedef', 'strtodatetime',
+            'strtodatetimedef', 'strtofloat', 'strtofloatdef', 'strtoint',
+            'strtoint64', 'strtoint64def', 'strtointdef', 'strtotime',
+            'strtotimedef', 'strupper', 'supports', 'syserrormessage',
+            'systemtimetodatetime', 'texttofloat', 'time', 'timestamptodatetime',
+            'timestamptomsecs', 'timetostr', 'trim', 'trimleft', 'trimright',
+            'tryencodedate', 'tryencodetime', 'tryfloattocurr', 'tryfloattodatetime',
+            'trystrtobool', 'trystrtocurr', 'trystrtodate', 'trystrtodatetime',
+            'trystrtofloat', 'trystrtoint', 'trystrtoint64', 'trystrtotime',
+            'unloadpackage', 'uppercase', 'widecomparestr', 'widecomparetext',
+            'widefmtstr', 'wideformat', 'wideformatbuf', 'widelowercase',
+            'widesamestr', 'widesametext', 'wideuppercase', 'win32check',
+            'wraptext'
+        ),
+        'Classes': (
+            'activateclassgroup', 'allocatehwnd', 'bintohex', 'checksynchronize',
+            'collectionsequal', 'countgenerations', 'deallocatehwnd', 'equalrect',
+            'extractstrings', 'findclass', 'findglobalcomponent', 'getclass',
+            'groupdescendantswith', 'hextobin', 'identtoint',
+            'initinheritedcomponent', 'inttoident', 'invalidpoint',
+            'isuniqueglobalcomponentname', 'linestart', 'objectbinarytotext',
+            'objectresourcetotext', 'objecttexttobinary', 'objecttexttoresource',
+            'pointsequal', 'readcomponentres', 'readcomponentresex',
+            'readcomponentresfile', 'rect', 'registerclass', 'registerclassalias',
+            'registerclasses', 'registercomponents', 'registerintegerconsts',
+            'registernoicon', 'registernonactivex', 'smallpoint', 'startclassgroup',
+            'teststreamformat', 'unregisterclass', 'unregisterclasses',
+            'unregisterintegerconsts', 'unregistermoduleclasses',
+            'writecomponentresfile'
+        ),
+        'Math': (
+            'arccos', 'arccosh', 'arccot', 'arccoth', 'arccsc', 'arccsch', 'arcsec',
+            'arcsech', 'arcsin', 'arcsinh', 'arctan2', 'arctanh', 'ceil',
+            'comparevalue', 'cosecant', 'cosh', 'cot', 'cotan', 'coth', 'csc',
+            'csch', 'cycletodeg', 'cycletograd', 'cycletorad', 'degtocycle',
+            'degtograd', 'degtorad', 'divmod', 'doubledecliningbalance',
+            'ensurerange', 'floor', 'frexp', 'futurevalue', 'getexceptionmask',
+            'getprecisionmode', 'getroundmode', 'gradtocycle', 'gradtodeg',
+            'gradtorad', 'hypot', 'inrange', 'interestpayment', 'interestrate',
+            'internalrateofreturn', 'intpower', 'isinfinite', 'isnan', 'iszero',
+            'ldexp', 'lnxp1', 'log10', 'log2', 'logn', 'max', 'maxintvalue',
+            'maxvalue', 'mean', 'meanandstddev', 'min', 'minintvalue', 'minvalue',
+            'momentskewkurtosis', 'netpresentvalue', 'norm', 'numberofperiods',
+            'payment', 'periodpayment', 'poly', 'popnstddev', 'popnvariance',
+            'power', 'presentvalue', 'radtocycle', 'radtodeg', 'radtograd',
+            'randg', 'randomrange', 'roundto', 'samevalue', 'sec', 'secant',
+            'sech', 'setexceptionmask', 'setprecisionmode', 'setroundmode',
+            'sign', 'simpleroundto', 'sincos', 'sinh', 'slndepreciation', 'stddev',
+            'sum', 'sumint', 'sumofsquares', 'sumsandsquares', 'syddepreciation',
+            'tan', 'tanh', 'totalvariance', 'variance'
+        )
+    }
+
+    ASM_REGISTERS = {
+        'ah', 'al', 'ax', 'bh', 'bl', 'bp', 'bx', 'ch', 'cl', 'cr0',
+        'cr1', 'cr2', 'cr3', 'cr4', 'cs', 'cx', 'dh', 'di', 'dl', 'dr0',
+        'dr1', 'dr2', 'dr3', 'dr4', 'dr5', 'dr6', 'dr7', 'ds', 'dx',
+        'eax', 'ebp', 'ebx', 'ecx', 'edi', 'edx', 'es', 'esi', 'esp',
+        'fs', 'gs', 'mm0', 'mm1', 'mm2', 'mm3', 'mm4', 'mm5', 'mm6',
+        'mm7', 'si', 'sp', 'ss', 'st0', 'st1', 'st2', 'st3', 'st4', 'st5',
+        'st6', 'st7', 'xmm0', 'xmm1', 'xmm2', 'xmm3', 'xmm4', 'xmm5',
+        'xmm6', 'xmm7'
+    }
+
+    ASM_INSTRUCTIONS = {
+        'aaa', 'aad', 'aam', 'aas', 'adc', 'add', 'and', 'arpl', 'bound',
+        'bsf', 'bsr', 'bswap', 'bt', 'btc', 'btr', 'bts', 'call', 'cbw',
+        'cdq', 'clc', 'cld', 'cli', 'clts', 'cmc', 'cmova', 'cmovae',
+        'cmovb', 'cmovbe', 'cmovc', 'cmovcxz', 'cmove', 'cmovg',
+        'cmovge', 'cmovl', 'cmovle', 'cmovna', 'cmovnae', 'cmovnb',
+        'cmovnbe', 'cmovnc', 'cmovne', 'cmovng', 'cmovnge', 'cmovnl',
+        'cmovnle', 'cmovno', 'cmovnp', 'cmovns', 'cmovnz', 'cmovo',
+        'cmovp', 'cmovpe', 'cmovpo', 'cmovs', 'cmovz', 'cmp', 'cmpsb',
+        'cmpsd', 'cmpsw', 'cmpxchg', 'cmpxchg486', 'cmpxchg8b', 'cpuid',
+        'cwd', 'cwde', 'daa', 'das', 'dec', 'div', 'emms', 'enter', 'hlt',
+        'ibts', 'icebp', 'idiv', 'imul', 'in', 'inc', 'insb', 'insd',
+        'insw', 'int', 'int01', 'int03', 'int1', 'int3', 'into', 'invd',
+        'invlpg', 'iret', 'iretd', 'iretw', 'ja', 'jae', 'jb', 'jbe',
+        'jc', 'jcxz', 'jcxz', 'je', 'jecxz', 'jg', 'jge', 'jl', 'jle',
+        'jmp', 'jna', 'jnae', 'jnb', 'jnbe', 'jnc', 'jne', 'jng', 'jnge',
+        'jnl', 'jnle', 'jno', 'jnp', 'jns', 'jnz', 'jo', 'jp', 'jpe',
+        'jpo', 'js', 'jz', 'lahf', 'lar', 'lcall', 'lds', 'lea', 'leave',
+        'les', 'lfs', 'lgdt', 'lgs', 'lidt', 'ljmp', 'lldt', 'lmsw',
+        'loadall', 'loadall286', 'lock', 'lodsb', 'lodsd', 'lodsw',
+        'loop', 'loope', 'loopne', 'loopnz', 'loopz', 'lsl', 'lss', 'ltr',
+        'mov', 'movd', 'movq', 'movsb', 'movsd', 'movsw', 'movsx',
+        'movzx', 'mul', 'neg', 'nop', 'not', 'or', 'out', 'outsb', 'outsd',
+        'outsw', 'pop', 'popa', 'popad', 'popaw', 'popf', 'popfd', 'popfw',
+        'push', 'pusha', 'pushad', 'pushaw', 'pushf', 'pushfd', 'pushfw',
+        'rcl', 'rcr', 'rdmsr', 'rdpmc', 'rdshr', 'rdtsc', 'rep', 'repe',
+        'repne', 'repnz', 'repz', 'ret', 'retf', 'retn', 'rol', 'ror',
+        'rsdc', 'rsldt', 'rsm', 'sahf', 'sal', 'salc', 'sar', 'sbb',
+        'scasb', 'scasd', 'scasw', 'seta', 'setae', 'setb', 'setbe',
+        'setc', 'setcxz', 'sete', 'setg', 'setge', 'setl', 'setle',
+        'setna', 'setnae', 'setnb', 'setnbe', 'setnc', 'setne', 'setng',
+        'setnge', 'setnl', 'setnle', 'setno', 'setnp', 'setns', 'setnz',
+        'seto', 'setp', 'setpe', 'setpo', 'sets', 'setz', 'sgdt', 'shl',
+        'shld', 'shr', 'shrd', 'sidt', 'sldt', 'smi', 'smint', 'smintold',
+        'smsw', 'stc', 'std', 'sti', 'stosb', 'stosd', 'stosw', 'str',
+        'sub', 'svdc', 'svldt', 'svts', 'syscall', 'sysenter', 'sysexit',
+        'sysret', 'test', 'ud1', 'ud2', 'umov', 'verr', 'verw', 'wait',
+        'wbinvd', 'wrmsr', 'wrshr', 'xadd', 'xbts', 'xchg', 'xlat',
+        'xlatb', 'xor'
+    }
+
+    PORTUGOL_KEYWORDS = (
+        'aleatorio',
+        'algoritmo',
+        'arquivo',
+        'ate',
+        'caso',
+        'cronometro',
+        'debug',
+        'e',
+        'eco',
+        'enquanto',
+        'entao',
+        'escolha',
+        'escreva',
+        'escreval',
+        'faca',
+        'falso',
+        'fimalgoritmo',
+        'fimenquanto',
+        'fimescolha',
+        'fimfuncao',
+        'fimpara',
+        'fimprocedimento',
+        'fimrepita',
+        'fimse',
+        'funcao',
+        'inicio',
+        'int',
+        'interrompa',
+        'leia',
+        'limpatela',
+        'mod',
+        'nao',
+        'ou',
+        'outrocaso',
+        'para',
+        'passo',
+        'pausa',
+        'procedimento',
+        'repita',
+        'retorne',
+        'se',
+        'senao',
+        'timer',
+        'var',
+        'vetor',
+        'verdadeiro',
+        'xou',
+        'div',
+        'mod',
+        'abs',
+        'arccos',
+        'arcsen',
+        'arctan',
+        'cos',
+        'cotan',
+        'Exp',
+        'grauprad',
+        'int',
+        'log',
+        'logn',
+        'pi',
+        'quad',
+        'radpgrau',
+        'raizq',
+        'rand',
+        'randi',
+        'sen',
+        'Tan',
+        'asc',
+        'carac',
+        'caracpnum',
+        'compr',
+        'copia',
+        'maiusc',
+        'minusc',
+        'numpcarac',
+        'pos',
+    )
+
+    PORTUGOL_BUILTIN_TYPES = {
+        'inteiro', 'real', 'caractere', 'logico'
+    }
+
+    def __init__(self, **options):
+        Lexer.__init__(self, **options)
+        self.keywords = set()
+        self.builtins = set()
+        if get_bool_opt(options, 'portugol', False):
+            self.keywords.update(self.PORTUGOL_KEYWORDS)
+            self.builtins.update(self.PORTUGOL_BUILTIN_TYPES)
+            self.is_portugol = True
+        else:
+            self.is_portugol = False
+
+            if get_bool_opt(options, 'turbopascal', True):
+                self.keywords.update(self.TURBO_PASCAL_KEYWORDS)
+            if get_bool_opt(options, 'delphi', True):
+                self.keywords.update(self.DELPHI_KEYWORDS)
+            if get_bool_opt(options, 'freepascal', True):
+                self.keywords.update(self.FREE_PASCAL_KEYWORDS)
+            for unit in get_list_opt(options, 'units', list(self.BUILTIN_UNITS)):
+                self.builtins.update(self.BUILTIN_UNITS[unit])
+
+    def get_tokens_unprocessed(self, text):
+        scanner = Scanner(text, re.DOTALL | re.MULTILINE | re.IGNORECASE)
+        stack = ['initial']
+        in_function_block = False
+        in_property_block = False
+        was_dot = False
+        next_token_is_function = False
+        next_token_is_property = False
+        collect_labels = False
+        block_labels = set()
+        brace_balance = [0, 0]
+
+        while not scanner.eos:
+            token = Error
+
+            if stack[-1] == 'initial':
+                if scanner.scan(r'\s+'):
+                    token = Whitespace
+                elif not self.is_portugol and scanner.scan(r'\{.*?\}|\(\*.*?\*\)'):
+                    if scanner.match.startswith('$'):
+                        token = Comment.Preproc
+                    else:
+                        token = Comment.Multiline
+                elif scanner.scan(r'//.*?$'):
+                    token = Comment.Single
+                elif self.is_portugol and scanner.scan(r'(<\-)|(>=)|(<=)|%|<|>|-|\+|\*|\=|(<>)|\/|\.|:|,'):
+                    token = Operator
+                elif not self.is_portugol and scanner.scan(r'[-+*\/=<>:;,.@\^]'):
+                    token = Operator
+                    # stop label highlighting on next ";"
+                    if collect_labels and scanner.match == ';':
+                        collect_labels = False
+                elif scanner.scan(r'[\(\)\[\]]+'):
+                    token = Punctuation
+                    # abort function naming ``foo = Function(...)``
+                    next_token_is_function = False
+                    # if we are in a function block we count the open
+                    # braces because ootherwise it's impossible to
+                    # determine the end of the modifier context
+                    if in_function_block or in_property_block:
+                        if scanner.match == '(':
+                            brace_balance[0] += 1
+                        elif scanner.match == ')':
+                            brace_balance[0] -= 1
+                        elif scanner.match == '[':
+                            brace_balance[1] += 1
+                        elif scanner.match == ']':
+                            brace_balance[1] -= 1
+                elif scanner.scan(r'[A-Za-z_][A-Za-z_0-9]*'):
+                    lowercase_name = scanner.match.lower()
+                    if lowercase_name == 'result':
+                        token = Name.Builtin.Pseudo
+                    elif lowercase_name in self.keywords:
+                        token = Keyword
+                        # if we are in a special block and a
+                        # block ending keyword occurs (and the parenthesis
+                        # is balanced) we end the current block context
+                        if self.is_portugol:
+                            if lowercase_name in ('funcao', 'procedimento'):
+                                in_function_block = True
+                                next_token_is_function = True
+                        else:
+                            if (in_function_block or in_property_block) and \
+                                    lowercase_name in self.BLOCK_KEYWORDS and \
+                                    brace_balance[0] <= 0 and \
+                                    brace_balance[1] <= 0:
+                                in_function_block = False
+                                in_property_block = False
+                                brace_balance = [0, 0]
+                                block_labels = set()
+                            if lowercase_name in ('label', 'goto'):
+                                collect_labels = True
+                            elif lowercase_name == 'asm':
+                                stack.append('asm')
+                            elif lowercase_name == 'property':
+                                in_property_block = True
+                                next_token_is_property = True
+                            elif lowercase_name in ('procedure', 'operator',
+                                                    'function', 'constructor',
+                                                    'destructor'):
+                                in_function_block = True
+                                next_token_is_function = True
+                    # we are in a function block and the current name
+                    # is in the set of registered modifiers. highlight
+                    # it as pseudo keyword
+                    elif not self.is_portugol and in_function_block and \
+                            lowercase_name in self.FUNCTION_MODIFIERS:
+                        token = Keyword.Pseudo
+                    # if we are in a property highlight some more
+                    # modifiers
+                    elif not self.is_portugol and in_property_block and \
+                            lowercase_name in ('read', 'write'):
+                        token = Keyword.Pseudo
+                        next_token_is_function = True
+                    # if the last iteration set next_token_is_function
+                    # to true we now want this name highlighted as
+                    # function. so do that and reset the state
+                    elif next_token_is_function:
+                        # Look if the next token is a dot. If yes it's
+                        # not a function, but a class name and the
+                        # part after the dot a function name
+                        if not self.is_portugol and scanner.test(r'\s*\.\s*'):
+                            token = Name.Class
+                        # it's not a dot, our job is done
+                        else:
+                            token = Name.Function
+                            next_token_is_function = False
+
+                            if self.is_portugol:
+                                block_labels.add(scanner.match.lower())
+
+                    # same for properties
+                    elif not self.is_portugol and next_token_is_property:
+                        token = Name.Property
+                        next_token_is_property = False
+                    # Highlight this token as label and add it
+                    # to the list of known labels
+                    elif not self.is_portugol and collect_labels:
+                        token = Name.Label
+                        block_labels.add(scanner.match.lower())
+                    # name is in list of known labels
+                    elif lowercase_name in block_labels:
+                        token = Name.Label
+                    elif self.is_portugol and lowercase_name in self.PORTUGOL_BUILTIN_TYPES:
+                        token = Keyword.Type
+                    elif not self.is_portugol and lowercase_name in self.BUILTIN_TYPES:
+                        token = Keyword.Type
+                    elif not self.is_portugol and lowercase_name in self.DIRECTIVES:
+                        token = Keyword.Pseudo
+                    # builtins are just builtins if the token
+                    # before isn't a dot
+                    elif not self.is_portugol and not was_dot and lowercase_name in self.builtins:
+                        token = Name.Builtin
+                    else:
+                        token = Name
+                elif self.is_portugol and scanner.scan(r"\""):
+                    token = String
+                    stack.append('string')
+                elif not self.is_portugol and scanner.scan(r"'"):
+                    token = String
+                    stack.append('string')
+                elif not self.is_portugol and scanner.scan(r'\#(\d+|\$[0-9A-Fa-f]+)'):
+                    token = String.Char
+                elif not self.is_portugol and scanner.scan(r'\$[0-9A-Fa-f]+'):
+                    token = Number.Hex
+                elif scanner.scan(r'\d+(?![eE]|\.[^.])'):
+                    token = Number.Integer
+                elif scanner.scan(r'\d+(\.\d+([eE][+-]?\d+)?|[eE][+-]?\d+)'):
+                    token = Number.Float
+                else:
+                    # if the stack depth is deeper than once, pop
+                    if len(stack) > 1:
+                        stack.pop()
+                    scanner.get_char()
+
+            elif stack[-1] == 'string':
+                if self.is_portugol:
+                    if scanner.scan(r"''"):
+                        token = String.Escape
+                    elif scanner.scan(r"\""):
+                        token = String
+                        stack.pop()
+                    elif scanner.scan(r"[^\"]*"):
+                        token = String
+                    else:
+                        scanner.get_char()
+                        stack.pop()
+                else:
+                    if scanner.scan(r"''"):
+                        token = String.Escape
+                    elif scanner.scan(r"'"):
+                        token = String
+                        stack.pop()
+                    elif scanner.scan(r"[^']*"):
+                        token = String
+                    else:
+                        scanner.get_char()
+                        stack.pop()
+            elif not self.is_portugol and stack[-1] == 'asm':
+                if scanner.scan(r'\s+'):
+                    token = Whitespace
+                elif scanner.scan(r'end'):
+                    token = Keyword
+                    stack.pop()
+                elif scanner.scan(r'\{.*?\}|\(\*.*?\*\)'):
+                    if scanner.match.startswith('$'):
+                        token = Comment.Preproc
+                    else:
+                        token = Comment.Multiline
+                elif scanner.scan(r'//.*?$'):
+                    token = Comment.Single
+                elif scanner.scan(r"'"):
+                    token = String
+                    stack.append('string')
+                elif scanner.scan(r'@@[A-Za-z_][A-Za-z_0-9]*'):
+                    token = Name.Label
+                elif scanner.scan(r'[A-Za-z_][A-Za-z_0-9]*'):
+                    lowercase_name = scanner.match.lower()
+                    if lowercase_name in self.ASM_INSTRUCTIONS:
+                        token = Keyword
+                    elif lowercase_name in self.ASM_REGISTERS:
+                        token = Name.Builtin
+                    else:
+                        token = Name
+                elif scanner.scan(r'[-+*\/=<>:;,.@\^]+'):
+                    token = Operator
+                elif scanner.scan(r'[\(\)\[\]]+'):
+                    token = Punctuation
+                elif scanner.scan(r'\$[0-9A-Fa-f]+'):
+                    token = Number.Hex
+                elif scanner.scan(r'\d+(?![eE]|\.[^.])'):
+                    token = Number.Integer
+                elif scanner.scan(r'\d+(\.\d+([eE][+-]?\d+)?|[eE][+-]?\d+)'):
+                    token = Number.Float
+                else:
+                    scanner.get_char()
+                    stack.pop()
+
+            # save the dot!!!11
+            if not self.is_portugol and scanner.match.strip():
+                was_dot = scanner.match == '.'
+
+            yield scanner.start_pos, token, scanner.match or ''
diff --git a/pygments/lexers/pawn.py b/pygments/lexers/pawn.py
new file mode 100644
index 0000000000000000000000000000000000000000..99d9c963d1473206ca68c389b068ba783783ce92
--- /dev/null
+++ b/pygments/lexers/pawn.py
@@ -0,0 +1,202 @@
+"""
+    pygments.lexers.pawn
+    ~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for the Pawn languages.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer
+from pygments.token import Text, Comment, Operator, Keyword, Name, String, \
+    Number, Punctuation
+from pygments.util import get_bool_opt
+
+__all__ = ['SourcePawnLexer', 'PawnLexer']
+
+
+class SourcePawnLexer(RegexLexer):
+    """
+    For SourcePawn source code with preprocessor directives.
+    """
+    name = 'SourcePawn'
+    aliases = ['sp']
+    filenames = ['*.sp']
+    mimetypes = ['text/x-sourcepawn']
+    url = 'https://github.com/alliedmodders/sourcepawn'
+    version_added = '1.6'
+
+    #: optional Comment or Whitespace
+    _ws = r'(?:\s|//.*?\n|/\*.*?\*/)+'
+    #: only one /* */ style comment
+    _ws1 = r'\s*(?:/[*].*?[*]/\s*)*'
+
+    tokens = {
+        'root': [
+            # preprocessor directives: without whitespace
+            (r'^#if\s+0', Comment.Preproc, 'if0'),
+            ('^#', Comment.Preproc, 'macro'),
+            # or with whitespace
+            ('^' + _ws1 + r'#if\s+0', Comment.Preproc, 'if0'),
+            ('^' + _ws1 + '#', Comment.Preproc, 'macro'),
+            (r'\n', Text),
+            (r'\s+', Text),
+            (r'\\\n', Text),  # line continuation
+            (r'/(\\\n)?/(\n|(.|\n)*?[^\\]\n)', Comment.Single),
+            (r'/(\\\n)?\*(.|\n)*?\*(\\\n)?/', Comment.Multiline),
+            (r'[{}]', Punctuation),
+            (r'L?"', String, 'string'),
+            (r"L?'(\\.|\\[0-7]{1,3}|\\x[a-fA-F0-9]{1,2}|[^\\\'\n])'", String.Char),
+            (r'(\d+\.\d*|\.\d+|\d+)[eE][+-]?\d+[LlUu]*', Number.Float),
+            (r'(\d+\.\d*|\.\d+|\d+[fF])[fF]?', Number.Float),
+            (r'0x[0-9a-fA-F]+[LlUu]*', Number.Hex),
+            (r'0[0-7]+[LlUu]*', Number.Oct),
+            (r'\d+[LlUu]*', Number.Integer),
+            (r'[~!%^&*+=|?:<>/-]', Operator),
+            (r'[()\[\],.;]', Punctuation),
+            (r'(case|const|continue|native|'
+             r'default|else|enum|for|if|new|operator|'
+             r'public|return|sizeof|static|decl|struct|switch)\b', Keyword),
+            (r'(bool|Float)\b', Keyword.Type),
+            (r'(true|false)\b', Keyword.Constant),
+            (r'[a-zA-Z_]\w*', Name),
+        ],
+        'string': [
+            (r'"', String, '#pop'),
+            (r'\\([\\abfnrtv"\']|x[a-fA-F0-9]{2,4}|[0-7]{1,3})', String.Escape),
+            (r'[^\\"\n]+', String),  # all other characters
+            (r'\\\n', String),       # line continuation
+            (r'\\', String),         # stray backslash
+        ],
+        'macro': [
+            (r'[^/\n]+', Comment.Preproc),
+            (r'/\*(.|\n)*?\*/', Comment.Multiline),
+            (r'//.*?\n', Comment.Single, '#pop'),
+            (r'/', Comment.Preproc),
+            (r'(?<=\\)\n', Comment.Preproc),
+            (r'\n', Comment.Preproc, '#pop'),
+        ],
+        'if0': [
+            (r'^\s*#if.*?(?/-]', Operator),
+            (r'[()\[\],.;]', Punctuation),
+            (r'(switch|case|default|const|new|static|char|continue|break|'
+             r'if|else|for|while|do|operator|enum|'
+             r'public|return|sizeof|tagof|state|goto)\b', Keyword),
+            (r'(bool|Float)\b', Keyword.Type),
+            (r'(true|false)\b', Keyword.Constant),
+            (r'[a-zA-Z_]\w*', Name),
+        ],
+        'string': [
+            (r'"', String, '#pop'),
+            (r'\\([\\abfnrtv"\']|x[a-fA-F0-9]{2,4}|[0-7]{1,3})', String.Escape),
+            (r'[^\\"\n]+', String),  # all other characters
+            (r'\\\n', String),       # line continuation
+            (r'\\', String),         # stray backslash
+        ],
+        'macro': [
+            (r'[^/\n]+', Comment.Preproc),
+            (r'/\*(.|\n)*?\*/', Comment.Multiline),
+            (r'//.*?\n', Comment.Single, '#pop'),
+            (r'/', Comment.Preproc),
+            (r'(?<=\\)\n', Comment.Preproc),
+            (r'\n', Comment.Preproc, '#pop'),
+        ],
+        'if0': [
+            (r'^\s*#if.*?(?<-]', Operator),
+            (r'[a-zA-Z][a-zA-Z0-9_-]*', Name),
+            (r'\?[a-zA-Z][a-zA-Z0-9_-]*', Name.Variable),
+            (r'[0-9]+\.[0-9]+', Number.Float),
+            (r'[0-9]+', Number.Integer),
+        ],
+        'keywords': [
+            (words((
+                ':requirements', ':types', ':constants',
+                ':predicates', ':functions', ':action', ':agent',
+                ':parameters', ':precondition', ':effect',
+                ':durative-action', ':duration', ':condition',
+                ':derived', ':domain', ':objects', ':init',
+                ':goal', ':metric', ':length', ':serial', ':parallel',
+                # the following are requirements
+                ':strips', ':typing', ':negative-preconditions',
+                ':disjunctive-preconditions', ':equality',
+                ':existential-preconditions', ':universal-preconditions',
+                ':conditional-effects', ':fluents', ':numeric-fluents',
+                ':object-fluents', ':adl', ':durative-actions',
+                ':continuous-effects', ':derived-predicates',
+                ':time-intial-literals', ':preferences',
+                ':constraints', ':action-costs', ':multi-agent',
+                ':unfactored-privacy', ':factored-privacy',
+                ':non-deterministic'
+                ), suffix=r'\b'), Keyword)
+        ],
+        'builtins': [
+            (words((
+                'define', 'domain', 'object', 'either', 'and',
+                'forall', 'preference', 'imply', 'or', 'exists',
+                'not', 'when', 'assign', 'scale-up', 'scale-down',
+                'increase', 'decrease', 'at', 'over', 'start',
+                'end', 'all', 'problem', 'always', 'sometime',
+                'within', 'at-most-once', 'sometime-after',
+                'sometime-before', 'always-within', 'hold-during',
+                'hold-after', 'minimize', 'maximize',
+                'total-time', 'is-violated'), suffix=r'\b'),
+                Name.Builtin)
+        ]
+    }
+
diff --git a/pygments/lexers/perl.py b/pygments/lexers/perl.py
new file mode 100644
index 0000000000000000000000000000000000000000..33f91f5800b7e40c8bb195f11974454eb3e6aabf
--- /dev/null
+++ b/pygments/lexers/perl.py
@@ -0,0 +1,733 @@
+"""
+    pygments.lexers.perl
+    ~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for Perl, Raku and related languages.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import re
+
+from pygments.lexer import RegexLexer, ExtendedRegexLexer, include, bygroups, \
+    using, this, default, words
+from pygments.token import Text, Comment, Operator, Keyword, Name, String, \
+    Number, Punctuation, Whitespace
+from pygments.util import shebang_matches
+
+__all__ = ['PerlLexer', 'Perl6Lexer']
+
+
+class PerlLexer(RegexLexer):
+    """
+    For Perl source code.
+    """
+
+    name = 'Perl'
+    url = 'https://www.perl.org'
+    aliases = ['perl', 'pl']
+    filenames = ['*.pl', '*.pm', '*.t', '*.perl']
+    mimetypes = ['text/x-perl', 'application/x-perl']
+    version_added = ''
+
+    flags = re.DOTALL | re.MULTILINE
+    # TODO: give this to a perl guy who knows how to parse perl...
+    tokens = {
+        'balanced-regex': [
+            (r'/(\\\\|\\[^\\]|[^\\/])*/[egimosx]*', String.Regex, '#pop'),
+            (r'!(\\\\|\\[^\\]|[^\\!])*![egimosx]*', String.Regex, '#pop'),
+            (r'\\(\\\\|[^\\])*\\[egimosx]*', String.Regex, '#pop'),
+            (r'\{(\\\\|\\[^\\]|[^\\}])*\}[egimosx]*', String.Regex, '#pop'),
+            (r'<(\\\\|\\[^\\]|[^\\>])*>[egimosx]*', String.Regex, '#pop'),
+            (r'\[(\\\\|\\[^\\]|[^\\\]])*\][egimosx]*', String.Regex, '#pop'),
+            (r'\((\\\\|\\[^\\]|[^\\)])*\)[egimosx]*', String.Regex, '#pop'),
+            (r'@(\\\\|\\[^\\]|[^\\@])*@[egimosx]*', String.Regex, '#pop'),
+            (r'%(\\\\|\\[^\\]|[^\\%])*%[egimosx]*', String.Regex, '#pop'),
+            (r'\$(\\\\|\\[^\\]|[^\\$])*\$[egimosx]*', String.Regex, '#pop'),
+        ],
+        'root': [
+            (r'\A\#!.+?$', Comment.Hashbang),
+            (r'\#.*?$', Comment.Single),
+            (r'^=[a-zA-Z0-9]+\s+.*?\n=cut', Comment.Multiline),
+            (words((
+                'case', 'continue', 'do', 'else', 'elsif', 'for', 'foreach',
+                'if', 'last', 'my', 'next', 'our', 'redo', 'reset', 'then',
+                'unless', 'until', 'while', 'print', 'new', 'BEGIN',
+                'CHECK', 'INIT', 'END', 'return'), suffix=r'\b'),
+             Keyword),
+            (r'(format)(\s+)(\w+)(\s*)(=)(\s*\n)',
+             bygroups(Keyword, Whitespace, Name, Whitespace, Punctuation, Whitespace), 'format'),
+            (r'(eq|lt|gt|le|ge|ne|not|and|or|cmp)\b', Operator.Word),
+            # common delimiters
+            (r's/(\\\\|\\[^\\]|[^\\/])*/(\\\\|\\[^\\]|[^\\/])*/[egimosx]*',
+                String.Regex),
+            (r's!(\\\\|\\!|[^!])*!(\\\\|\\!|[^!])*![egimosx]*', String.Regex),
+            (r's\\(\\\\|[^\\])*\\(\\\\|[^\\])*\\[egimosx]*', String.Regex),
+            (r's@(\\\\|\\[^\\]|[^\\@])*@(\\\\|\\[^\\]|[^\\@])*@[egimosx]*',
+                String.Regex),
+            (r's%(\\\\|\\[^\\]|[^\\%])*%(\\\\|\\[^\\]|[^\\%])*%[egimosx]*',
+                String.Regex),
+            # balanced delimiters
+            (r's\{(\\\\|\\[^\\]|[^\\}])*\}\s*', String.Regex, 'balanced-regex'),
+            (r's<(\\\\|\\[^\\]|[^\\>])*>\s*', String.Regex, 'balanced-regex'),
+            (r's\[(\\\\|\\[^\\]|[^\\\]])*\]\s*', String.Regex,
+                'balanced-regex'),
+            (r's\((\\\\|\\[^\\]|[^\\)])*\)\s*', String.Regex,
+                'balanced-regex'),
+
+            (r'm?/(\\\\|\\[^\\]|[^\\/\n])*/[gcimosx]*', String.Regex),
+            (r'm(?=[/!\\{<\[(@%$])', String.Regex, 'balanced-regex'),
+            (r'((?<==~)|(?<=\())\s*/(\\\\|\\[^\\]|[^\\/])*/[gcimosx]*',
+                String.Regex),
+            (r'\s+', Whitespace),
+            (words((
+                'abs', 'accept', 'alarm', 'atan2', 'bind', 'binmode', 'bless', 'caller', 'chdir',
+                'chmod', 'chomp', 'chop', 'chown', 'chr', 'chroot', 'close', 'closedir', 'connect',
+                'continue', 'cos', 'crypt', 'dbmclose', 'dbmopen', 'defined', 'delete', 'die',
+                'dump', 'each', 'endgrent', 'endhostent', 'endnetent', 'endprotoent',
+                'endpwent', 'endservent', 'eof', 'eval', 'exec', 'exists', 'exit', 'exp', 'fcntl',
+                'fileno', 'flock', 'fork', 'format', 'formline', 'getc', 'getgrent', 'getgrgid',
+                'getgrnam', 'gethostbyaddr', 'gethostbyname', 'gethostent', 'getlogin',
+                'getnetbyaddr', 'getnetbyname', 'getnetent', 'getpeername', 'getpgrp',
+                'getppid', 'getpriority', 'getprotobyname', 'getprotobynumber',
+                'getprotoent', 'getpwent', 'getpwnam', 'getpwuid', 'getservbyname',
+                'getservbyport', 'getservent', 'getsockname', 'getsockopt', 'glob', 'gmtime',
+                'goto', 'grep', 'hex', 'import', 'index', 'int', 'ioctl', 'join', 'keys', 'kill', 'last',
+                'lc', 'lcfirst', 'length', 'link', 'listen', 'local', 'localtime', 'log', 'lstat',
+                'map', 'mkdir', 'msgctl', 'msgget', 'msgrcv', 'msgsnd', 'my', 'next', 'oct', 'open',
+                'opendir', 'ord', 'our', 'pack', 'pipe', 'pop', 'pos', 'printf',
+                'prototype', 'push', 'quotemeta', 'rand', 'read', 'readdir',
+                'readline', 'readlink', 'readpipe', 'recv', 'redo', 'ref', 'rename',
+                'reverse', 'rewinddir', 'rindex', 'rmdir', 'scalar', 'seek', 'seekdir',
+                'select', 'semctl', 'semget', 'semop', 'send', 'setgrent', 'sethostent', 'setnetent',
+                'setpgrp', 'setpriority', 'setprotoent', 'setpwent', 'setservent',
+                'setsockopt', 'shift', 'shmctl', 'shmget', 'shmread', 'shmwrite', 'shutdown',
+                'sin', 'sleep', 'socket', 'socketpair', 'sort', 'splice', 'split', 'sprintf', 'sqrt',
+                'srand', 'stat', 'study', 'substr', 'symlink', 'syscall', 'sysopen', 'sysread',
+                'sysseek', 'system', 'syswrite', 'tell', 'telldir', 'tie', 'tied', 'time', 'times', 'tr',
+                'truncate', 'uc', 'ucfirst', 'umask', 'undef', 'unlink', 'unpack', 'unshift', 'untie',
+                'utime', 'values', 'vec', 'wait', 'waitpid', 'wantarray', 'warn', 'write'), suffix=r'\b'),
+             Name.Builtin),
+            (r'((__(DATA|DIE|WARN)__)|(STD(IN|OUT|ERR)))\b', Name.Builtin.Pseudo),
+            (r'(<<)([\'"]?)([a-zA-Z_]\w*)(\2;?\n.*?\n)(\3)(\n)',
+             bygroups(String, String, String.Delimiter, String, String.Delimiter, Whitespace)),
+            (r'__END__', Comment.Preproc, 'end-part'),
+            (r'\$\^[ADEFHILMOPSTWX]', Name.Variable.Global),
+            (r"\$[\\\"\[\]'&`+*.,;=%~?@$!<>(^|/-](?!\w)", Name.Variable.Global),
+            (r'[$@%#]+', Name.Variable, 'varname'),
+            (r'0_?[0-7]+(_[0-7]+)*', Number.Oct),
+            (r'0x[0-9A-Fa-f]+(_[0-9A-Fa-f]+)*', Number.Hex),
+            (r'0b[01]+(_[01]+)*', Number.Bin),
+            (r'(?i)(\d*(_\d*)*\.\d+(_\d*)*|\d+(_\d*)*\.\d+(_\d*)*)(e[+-]?\d+)?',
+             Number.Float),
+            (r'(?i)\d+(_\d*)*e[+-]?\d+(_\d*)*', Number.Float),
+            (r'\d+(_\d+)*', Number.Integer),
+            (r"'(\\\\|\\[^\\]|[^'\\])*'", String),
+            (r'"(\\\\|\\[^\\]|[^"\\])*"', String),
+            (r'`(\\\\|\\[^\\]|[^`\\])*`', String.Backtick),
+            (r'<([^\s>]+)>', String.Regex),
+            (r'(q|qq|qw|qr|qx)\{', String.Other, 'cb-string'),
+            (r'(q|qq|qw|qr|qx)\(', String.Other, 'rb-string'),
+            (r'(q|qq|qw|qr|qx)\[', String.Other, 'sb-string'),
+            (r'(q|qq|qw|qr|qx)\<', String.Other, 'lt-string'),
+            (r'(q|qq|qw|qr|qx)([\W_])(.|\n)*?\2', String.Other),
+            (r'(package)(\s+)([a-zA-Z_]\w*(?:::[a-zA-Z_]\w*)*)',
+             bygroups(Keyword, Whitespace, Name.Namespace)),
+            (r'(use|require|no)(\s+)([a-zA-Z_]\w*(?:::[a-zA-Z_]\w*)*)',
+             bygroups(Keyword, Whitespace, Name.Namespace)),
+            (r'(sub)(\s+)', bygroups(Keyword, Whitespace), 'funcname'),
+            (words((
+                'no', 'package', 'require', 'use'), suffix=r'\b'),
+             Keyword),
+            (r'(\[\]|\*\*|::|<<|>>|>=|<=>|<=|={3}|!=|=~|'
+             r'!~|&&?|\|\||\.{1,3})', Operator),
+            (r'[-+/*%=<>&^|!\\~]=?', Operator),
+            (r'[()\[\]:;,<>/?{}]', Punctuation),  # yes, there's no shortage
+                                                  # of punctuation in Perl!
+            (r'(?=\w)', Name, 'name'),
+        ],
+        'format': [
+            (r'\.\n', String.Interpol, '#pop'),
+            (r'[^\n]*\n', String.Interpol),
+        ],
+        'varname': [
+            (r'\s+', Whitespace),
+            (r'\{', Punctuation, '#pop'),    # hash syntax?
+            (r'\)|,', Punctuation, '#pop'),  # argument specifier
+            (r'\w+::', Name.Namespace),
+            (r'[\w:]+', Name.Variable, '#pop'),
+        ],
+        'name': [
+            (r'[a-zA-Z_]\w*(::[a-zA-Z_]\w*)*(::)?(?=\s*->)', Name.Namespace, '#pop'),
+            (r'[a-zA-Z_]\w*(::[a-zA-Z_]\w*)*::', Name.Namespace, '#pop'),
+            (r'[\w:]+', Name, '#pop'),
+            (r'[A-Z_]+(?=\W)', Name.Constant, '#pop'),
+            (r'(?=\W)', Text, '#pop'),
+        ],
+        'funcname': [
+            (r'[a-zA-Z_]\w*[!?]?', Name.Function),
+            (r'\s+', Whitespace),
+            # argument declaration
+            (r'(\([$@%]*\))(\s*)', bygroups(Punctuation, Whitespace)),
+            (r';', Punctuation, '#pop'),
+            (r'.*?\{', Punctuation, '#pop'),
+        ],
+        'cb-string': [
+            (r'\\[{}\\]', String.Other),
+            (r'\\', String.Other),
+            (r'\{', String.Other, 'cb-string'),
+            (r'\}', String.Other, '#pop'),
+            (r'[^{}\\]+', String.Other)
+        ],
+        'rb-string': [
+            (r'\\[()\\]', String.Other),
+            (r'\\', String.Other),
+            (r'\(', String.Other, 'rb-string'),
+            (r'\)', String.Other, '#pop'),
+            (r'[^()]+', String.Other)
+        ],
+        'sb-string': [
+            (r'\\[\[\]\\]', String.Other),
+            (r'\\', String.Other),
+            (r'\[', String.Other, 'sb-string'),
+            (r'\]', String.Other, '#pop'),
+            (r'[^\[\]]+', String.Other)
+        ],
+        'lt-string': [
+            (r'\\[<>\\]', String.Other),
+            (r'\\', String.Other),
+            (r'\<', String.Other, 'lt-string'),
+            (r'\>', String.Other, '#pop'),
+            (r'[^<>]+', String.Other)
+        ],
+        'end-part': [
+            (r'.+', Comment.Preproc, '#pop')
+        ]
+    }
+
+    def analyse_text(text):
+        if shebang_matches(text, r'perl'):
+            return True
+
+        result = 0
+
+        if re.search(r'(?:my|our)\s+[$@%(]', text):
+            result += 0.9
+
+        if ':=' in text:
+            # := is not valid Perl, but it appears in unicon, so we should
+            # become less confident if we think we found Perl with :=
+            result /= 2
+
+        return result
+
+
+class Perl6Lexer(ExtendedRegexLexer):
+    """
+    For Raku (a.k.a. Perl 6) source code.
+    """
+
+    name = 'Perl6'
+    url = 'https://www.raku.org'
+    aliases = ['perl6', 'pl6', 'raku']
+    filenames = ['*.pl', '*.pm', '*.nqp', '*.p6', '*.6pl', '*.p6l', '*.pl6',
+                 '*.6pm', '*.p6m', '*.pm6', '*.t', '*.raku', '*.rakumod',
+                 '*.rakutest', '*.rakudoc']
+    mimetypes = ['text/x-perl6', 'application/x-perl6']
+    version_added = '2.0'
+    flags = re.MULTILINE | re.DOTALL
+
+    PERL6_IDENTIFIER_RANGE = r"['\w:-]"
+
+    PERL6_KEYWORDS = (
+        #Phasers
+        'BEGIN','CATCH','CHECK','CLOSE','CONTROL','DOC','END','ENTER','FIRST',
+        'INIT','KEEP','LAST','LEAVE','NEXT','POST','PRE','QUIT','UNDO',
+        #Keywords
+        'anon','augment','but','class','constant','default','does','else',
+        'elsif','enum','for','gather','given','grammar','has','if','import',
+        'is','let','loop','made','make','method','module','multi','my','need',
+        'orwith','our','proceed','proto','repeat','require','return',
+        'return-rw','returns','role','rule','state','sub','submethod','subset',
+        'succeed','supersede','token','try','unit','unless','until','use',
+        'when','while','with','without',
+        #Traits
+        'export','native','repr','required','rw','symbol',
+    )
+
+    PERL6_BUILTINS = (
+        'ACCEPTS','abs','abs2rel','absolute','accept','accessed','acos',
+        'acosec','acosech','acosh','acotan','acotanh','acquire','act','action',
+        'actions','add','add_attribute','add_enum_value','add_fallback',
+        'add_method','add_parent','add_private_method','add_role','add_trustee',
+        'adverb','after','all','allocate','allof','allowed','alternative-names',
+        'annotations','antipair','antipairs','any','anyof','app_lifetime',
+        'append','arch','archname','args','arity','Array','asec','asech','asin',
+        'asinh','ASSIGN-KEY','ASSIGN-POS','assuming','ast','at','atan','atan2',
+        'atanh','AT-KEY','atomic-assign','atomic-dec-fetch','atomic-fetch',
+        'atomic-fetch-add','atomic-fetch-dec','atomic-fetch-inc',
+        'atomic-fetch-sub','atomic-inc-fetch','AT-POS','attributes','auth',
+        'await','backtrace','Bag','BagHash','bail-out','base','basename',
+        'base-repeating','batch','BIND-KEY','BIND-POS','bind-stderr',
+        'bind-stdin','bind-stdout','bind-udp','bits','bless','block','Bool',
+        'bool-only','bounds','break','Bridge','broken','BUILD','build-date',
+        'bytes','cache','callframe','calling-package','CALL-ME','callsame',
+        'callwith','can','cancel','candidates','cando','can-ok','canonpath',
+        'caps','caption','Capture','cas','catdir','categorize','categorize-list',
+        'catfile','catpath','cause','ceiling','cglobal','changed','Channel',
+        'chars','chdir','child','child-name','child-typename','chmod','chomp',
+        'chop','chr','chrs','chunks','cis','classify','classify-list','cleanup',
+        'clone','close','closed','close-stdin','cmp-ok','code','codes','collate',
+        'column','comb','combinations','command','comment','compiler','Complex',
+        'compose','compose_type','composer','condition','config',
+        'configure_destroy','configure_type_checking','conj','connect',
+        'constraints','construct','contains','contents','copy','cos','cosec',
+        'cosech','cosh','cotan','cotanh','count','count-only','cpu-cores',
+        'cpu-usage','CREATE','create_type','cross','cue','curdir','curupdir','d',
+        'Date','DateTime','day','daycount','day-of-month','day-of-week',
+        'day-of-year','days-in-month','declaration','decode','decoder','deepmap',
+        'default','defined','DEFINITE','delayed','DELETE-KEY','DELETE-POS',
+        'denominator','desc','DESTROY','destroyers','devnull','diag',
+        'did-you-mean','die','dies-ok','dir','dirname','dir-sep','DISTROnames',
+        'do','does','does-ok','done','done-testing','duckmap','dynamic','e',
+        'eager','earlier','elems','emit','enclosing','encode','encoder',
+        'encoding','end','ends-with','enum_from_value','enum_value_list',
+        'enum_values','enums','eof','EVAL','eval-dies-ok','EVALFILE',
+        'eval-lives-ok','exception','excludes-max','excludes-min','EXISTS-KEY',
+        'EXISTS-POS','exit','exitcode','exp','expected','explicitly-manage',
+        'expmod','extension','f','fail','fails-like','fc','feature','file',
+        'filename','find_method','find_method_qualified','finish','first','flat',
+        'flatmap','flip','floor','flunk','flush','fmt','format','formatter',
+        'freeze','from','from-list','from-loop','from-posix','full',
+        'full-barrier','get','get_value','getc','gist','got','grab','grabpairs',
+        'grep','handle','handled','handles','hardware','has_accessor','Hash',
+        'head','headers','hh-mm-ss','hidden','hides','hour','how','hyper','id',
+        'illegal','im','in','indent','index','indices','indir','infinite',
+        'infix','infix:<+>','infix:<->','install_method_cache','Instant',
+        'instead','Int','int-bounds','interval','in-timezone','invalid-str',
+        'invert','invocant','IO','IO::Notification.watch-path','is_trusted',
+        'is_type','isa','is-absolute','isa-ok','is-approx','is-deeply',
+        'is-hidden','is-initial-thread','is-int','is-lazy','is-leap-year',
+        'isNaN','isnt','is-prime','is-relative','is-routine','is-setting',
+        'is-win','item','iterator','join','keep','kept','KERNELnames','key',
+        'keyof','keys','kill','kv','kxxv','l','lang','last','lastcall','later',
+        'lazy','lc','leading','level','like','line','lines','link','List',
+        'listen','live','lives-ok','local','lock','log','log10','lookup','lsb',
+        'made','MAIN','make','Map','match','max','maxpairs','merge','message',
+        'method','method_table','methods','migrate','min','minmax','minpairs',
+        'minute','misplaced','Mix','MixHash','mkdir','mode','modified','month',
+        'move','mro','msb','multi','multiness','my','name','named','named_names',
+        'narrow','nativecast','native-descriptor','nativesizeof','new','new_type',
+        'new-from-daycount','new-from-pairs','next','nextcallee','next-handle',
+        'nextsame','nextwith','NFC','NFD','NFKC','NFKD','nl-in','nl-out',
+        'nodemap','nok','none','norm','not','note','now','nude','Num',
+        'numerator','Numeric','of','offset','offset-in-hours','offset-in-minutes',
+        'ok','old','on-close','one','on-switch','open','opened','operation',
+        'optional','ord','ords','orig','os-error','osname','out-buffer','pack',
+        'package','package-kind','package-name','packages','pair','pairs',
+        'pairup','parameter','params','parent','parent-name','parents','parse',
+        'parse-base','parsefile','parse-names','parts','pass','path','path-sep',
+        'payload','peer-host','peer-port','periods','perl','permutations','phaser',
+        'pick','pickpairs','pid','placeholder','plan','plus','polar','poll',
+        'polymod','pop','pos','positional','posix','postfix','postmatch',
+        'precomp-ext','precomp-target','pred','prefix','prematch','prepend',
+        'print','printf','print-nl','print-to','private','private_method_table',
+        'proc','produce','Promise','prompt','protect','pull-one','push',
+        'push-all','push-at-least','push-exactly','push-until-lazy','put',
+        'qualifier-type','quit','r','race','radix','rand','range','Rat','raw',
+        're','read','readchars','readonly','ready','Real','reallocate','reals',
+        'reason','rebless','receive','recv','redispatcher','redo','reduce',
+        'rel2abs','relative','release','rename','repeated','replacement',
+        'report','reserved','resolve','restore','result','resume','rethrow',
+        'reverse','right','rindex','rmdir','role','roles_to_compose','rolish',
+        'roll','rootdir','roots','rotate','rotor','round','roundrobin',
+        'routine-type','run','rwx','s','samecase','samemark','samewith','say',
+        'schedule-on','scheduler','scope','sec','sech','second','seek','self',
+        'send','Set','set_hidden','set_name','set_package','set_rw','set_value',
+        'SetHash','set-instruments','setup_finalization','shape','share','shell',
+        'shift','sibling','sigil','sign','signal','signals','signature','sin',
+        'sinh','sink','sink-all','skip','skip-at-least','skip-at-least-pull-one',
+        'skip-one','skip-rest','sleep','sleep-timer','sleep-until','Slip','slurp',
+        'slurp-rest','slurpy','snap','snapper','so','socket-host','socket-port',
+        'sort','source','source-package','spawn','SPEC','splice','split',
+        'splitdir','splitpath','sprintf','spurt','sqrt','squish','srand','stable',
+        'start','started','starts-with','status','stderr','stdout','Str',
+        'sub_signature','subbuf','subbuf-rw','subname','subparse','subst',
+        'subst-mutate','substr','substr-eq','substr-rw','subtest','succ','sum',
+        'Supply','symlink','t','tail','take','take-rw','tan','tanh','tap',
+        'target','target-name','tc','tclc','tell','then','throttle','throw',
+        'throws-like','timezone','tmpdir','to','today','todo','toggle','to-posix',
+        'total','trailing','trans','tree','trim','trim-leading','trim-trailing',
+        'truncate','truncated-to','trusts','try_acquire','trying','twigil','type',
+        'type_captures','typename','uc','udp','uncaught_handler','unimatch',
+        'uniname','uninames','uniparse','uniprop','uniprops','unique','unival',
+        'univals','unlike','unlink','unlock','unpack','unpolar','unshift',
+        'unwrap','updir','USAGE','use-ok','utc','val','value','values','VAR',
+        'variable','verbose-config','version','VMnames','volume','vow','w','wait',
+        'warn','watch','watch-path','week','weekday-of-month','week-number',
+        'week-year','WHAT','when','WHERE','WHEREFORE','WHICH','WHO',
+        'whole-second','WHY','wordcase','words','workaround','wrap','write',
+        'write-to','x','yada','year','yield','yyyy-mm-dd','z','zip','zip-latest',
+
+    )
+
+    PERL6_BUILTIN_CLASSES = (
+        #Booleans
+        'False','True',
+        #Classes
+        'Any','Array','Associative','AST','atomicint','Attribute','Backtrace',
+        'Backtrace::Frame','Bag','Baggy','BagHash','Blob','Block','Bool','Buf',
+        'Callable','CallFrame','Cancellation','Capture','CArray','Channel','Code',
+        'compiler','Complex','ComplexStr','Cool','CurrentThreadScheduler',
+        'Cursor','Date','Dateish','DateTime','Distro','Duration','Encoding',
+        'Exception','Failure','FatRat','Grammar','Hash','HyperWhatever','Instant',
+        'Int','int16','int32','int64','int8','IntStr','IO','IO::ArgFiles',
+        'IO::CatHandle','IO::Handle','IO::Notification','IO::Path',
+        'IO::Path::Cygwin','IO::Path::QNX','IO::Path::Unix','IO::Path::Win32',
+        'IO::Pipe','IO::Socket','IO::Socket::Async','IO::Socket::INET','IO::Spec',
+        'IO::Spec::Cygwin','IO::Spec::QNX','IO::Spec::Unix','IO::Spec::Win32',
+        'IO::Special','Iterable','Iterator','Junction','Kernel','Label','List',
+        'Lock','Lock::Async','long','longlong','Macro','Map','Match',
+        'Metamodel::AttributeContainer','Metamodel::C3MRO','Metamodel::ClassHOW',
+        'Metamodel::EnumHOW','Metamodel::Finalization','Metamodel::MethodContainer',
+        'Metamodel::MROBasedMethodDispatch','Metamodel::MultipleInheritance',
+        'Metamodel::Naming','Metamodel::Primitives','Metamodel::PrivateMethodContainer',
+        'Metamodel::RoleContainer','Metamodel::Trusting','Method','Mix','MixHash',
+        'Mixy','Mu','NFC','NFD','NFKC','NFKD','Nil','Num','num32','num64',
+        'Numeric','NumStr','ObjAt','Order','Pair','Parameter','Perl','Pod::Block',
+        'Pod::Block::Code','Pod::Block::Comment','Pod::Block::Declarator',
+        'Pod::Block::Named','Pod::Block::Para','Pod::Block::Table','Pod::Heading',
+        'Pod::Item','Pointer','Positional','PositionalBindFailover','Proc',
+        'Proc::Async','Promise','Proxy','PseudoStash','QuantHash','Range','Rat',
+        'Rational','RatStr','Real','Regex','Routine','Scalar','Scheduler',
+        'Semaphore','Seq','Set','SetHash','Setty','Signature','size_t','Slip',
+        'Stash','Str','StrDistance','Stringy','Sub','Submethod','Supplier',
+        'Supplier::Preserving','Supply','Systemic','Tap','Telemetry',
+        'Telemetry::Instrument::Thread','Telemetry::Instrument::Usage',
+        'Telemetry::Period','Telemetry::Sampler','Thread','ThreadPoolScheduler',
+        'UInt','uint16','uint32','uint64','uint8','Uni','utf8','Variable',
+        'Version','VM','Whatever','WhateverCode','WrapHandle'
+    )
+
+    PERL6_OPERATORS = (
+        'X', 'Z', 'after', 'also', 'and', 'andthen', 'before', 'cmp', 'div',
+        'eq', 'eqv', 'extra', 'ff', 'fff', 'ge', 'gt', 'le', 'leg', 'lt', 'm',
+        'mm', 'mod', 'ne', 'or', 'orelse', 'rx', 's', 'tr', 'x', 'xor', 'xx',
+        '++', '--', '**', '!', '+', '-', '~', '?', '|', '||', '+^', '~^', '?^',
+        '^', '*', '/', '%', '%%', '+&', '+<', '+>', '~&', '~<', '~>', '?&',
+        'gcd', 'lcm', '+', '-', '+|', '+^', '~|', '~^', '?|', '?^',
+        '~', '&', '^', 'but', 'does', '<=>', '..', '..^', '^..', '^..^',
+        '!=', '==', '<', '<=', '>', '>=', '~~', '===', '!eqv',
+        '&&', '||', '^^', '//', 'min', 'max', '??', '!!', 'ff', 'fff', 'so',
+        'not', '<==', '==>', '<<==', '==>>','unicmp',
+    )
+
+    # Perl 6 has a *lot* of possible bracketing characters
+    # this list was lifted from STD.pm6 (https://github.com/perl6/std)
+    PERL6_BRACKETS = {
+        '\u0028': '\u0029', '\u003c': '\u003e', '\u005b': '\u005d',
+        '\u007b': '\u007d', '\u00ab': '\u00bb', '\u0f3a': '\u0f3b',
+        '\u0f3c': '\u0f3d', '\u169b': '\u169c', '\u2018': '\u2019',
+        '\u201a': '\u2019', '\u201b': '\u2019', '\u201c': '\u201d',
+        '\u201e': '\u201d', '\u201f': '\u201d', '\u2039': '\u203a',
+        '\u2045': '\u2046', '\u207d': '\u207e', '\u208d': '\u208e',
+        '\u2208': '\u220b', '\u2209': '\u220c', '\u220a': '\u220d',
+        '\u2215': '\u29f5', '\u223c': '\u223d', '\u2243': '\u22cd',
+        '\u2252': '\u2253', '\u2254': '\u2255', '\u2264': '\u2265',
+        '\u2266': '\u2267', '\u2268': '\u2269', '\u226a': '\u226b',
+        '\u226e': '\u226f', '\u2270': '\u2271', '\u2272': '\u2273',
+        '\u2274': '\u2275', '\u2276': '\u2277', '\u2278': '\u2279',
+        '\u227a': '\u227b', '\u227c': '\u227d', '\u227e': '\u227f',
+        '\u2280': '\u2281', '\u2282': '\u2283', '\u2284': '\u2285',
+        '\u2286': '\u2287', '\u2288': '\u2289', '\u228a': '\u228b',
+        '\u228f': '\u2290', '\u2291': '\u2292', '\u2298': '\u29b8',
+        '\u22a2': '\u22a3', '\u22a6': '\u2ade', '\u22a8': '\u2ae4',
+        '\u22a9': '\u2ae3', '\u22ab': '\u2ae5', '\u22b0': '\u22b1',
+        '\u22b2': '\u22b3', '\u22b4': '\u22b5', '\u22b6': '\u22b7',
+        '\u22c9': '\u22ca', '\u22cb': '\u22cc', '\u22d0': '\u22d1',
+        '\u22d6': '\u22d7', '\u22d8': '\u22d9', '\u22da': '\u22db',
+        '\u22dc': '\u22dd', '\u22de': '\u22df', '\u22e0': '\u22e1',
+        '\u22e2': '\u22e3', '\u22e4': '\u22e5', '\u22e6': '\u22e7',
+        '\u22e8': '\u22e9', '\u22ea': '\u22eb', '\u22ec': '\u22ed',
+        '\u22f0': '\u22f1', '\u22f2': '\u22fa', '\u22f3': '\u22fb',
+        '\u22f4': '\u22fc', '\u22f6': '\u22fd', '\u22f7': '\u22fe',
+        '\u2308': '\u2309', '\u230a': '\u230b', '\u2329': '\u232a',
+        '\u23b4': '\u23b5', '\u2768': '\u2769', '\u276a': '\u276b',
+        '\u276c': '\u276d', '\u276e': '\u276f', '\u2770': '\u2771',
+        '\u2772': '\u2773', '\u2774': '\u2775', '\u27c3': '\u27c4',
+        '\u27c5': '\u27c6', '\u27d5': '\u27d6', '\u27dd': '\u27de',
+        '\u27e2': '\u27e3', '\u27e4': '\u27e5', '\u27e6': '\u27e7',
+        '\u27e8': '\u27e9', '\u27ea': '\u27eb', '\u2983': '\u2984',
+        '\u2985': '\u2986', '\u2987': '\u2988', '\u2989': '\u298a',
+        '\u298b': '\u298c', '\u298d': '\u298e', '\u298f': '\u2990',
+        '\u2991': '\u2992', '\u2993': '\u2994', '\u2995': '\u2996',
+        '\u2997': '\u2998', '\u29c0': '\u29c1', '\u29c4': '\u29c5',
+        '\u29cf': '\u29d0', '\u29d1': '\u29d2', '\u29d4': '\u29d5',
+        '\u29d8': '\u29d9', '\u29da': '\u29db', '\u29f8': '\u29f9',
+        '\u29fc': '\u29fd', '\u2a2b': '\u2a2c', '\u2a2d': '\u2a2e',
+        '\u2a34': '\u2a35', '\u2a3c': '\u2a3d', '\u2a64': '\u2a65',
+        '\u2a79': '\u2a7a', '\u2a7d': '\u2a7e', '\u2a7f': '\u2a80',
+        '\u2a81': '\u2a82', '\u2a83': '\u2a84', '\u2a8b': '\u2a8c',
+        '\u2a91': '\u2a92', '\u2a93': '\u2a94', '\u2a95': '\u2a96',
+        '\u2a97': '\u2a98', '\u2a99': '\u2a9a', '\u2a9b': '\u2a9c',
+        '\u2aa1': '\u2aa2', '\u2aa6': '\u2aa7', '\u2aa8': '\u2aa9',
+        '\u2aaa': '\u2aab', '\u2aac': '\u2aad', '\u2aaf': '\u2ab0',
+        '\u2ab3': '\u2ab4', '\u2abb': '\u2abc', '\u2abd': '\u2abe',
+        '\u2abf': '\u2ac0', '\u2ac1': '\u2ac2', '\u2ac3': '\u2ac4',
+        '\u2ac5': '\u2ac6', '\u2acd': '\u2ace', '\u2acf': '\u2ad0',
+        '\u2ad1': '\u2ad2', '\u2ad3': '\u2ad4', '\u2ad5': '\u2ad6',
+        '\u2aec': '\u2aed', '\u2af7': '\u2af8', '\u2af9': '\u2afa',
+        '\u2e02': '\u2e03', '\u2e04': '\u2e05', '\u2e09': '\u2e0a',
+        '\u2e0c': '\u2e0d', '\u2e1c': '\u2e1d', '\u2e20': '\u2e21',
+        '\u3008': '\u3009', '\u300a': '\u300b', '\u300c': '\u300d',
+        '\u300e': '\u300f', '\u3010': '\u3011', '\u3014': '\u3015',
+        '\u3016': '\u3017', '\u3018': '\u3019', '\u301a': '\u301b',
+        '\u301d': '\u301e', '\ufd3e': '\ufd3f', '\ufe17': '\ufe18',
+        '\ufe35': '\ufe36', '\ufe37': '\ufe38', '\ufe39': '\ufe3a',
+        '\ufe3b': '\ufe3c', '\ufe3d': '\ufe3e', '\ufe3f': '\ufe40',
+        '\ufe41': '\ufe42', '\ufe43': '\ufe44', '\ufe47': '\ufe48',
+        '\ufe59': '\ufe5a', '\ufe5b': '\ufe5c', '\ufe5d': '\ufe5e',
+        '\uff08': '\uff09', '\uff1c': '\uff1e', '\uff3b': '\uff3d',
+        '\uff5b': '\uff5d', '\uff5f': '\uff60', '\uff62': '\uff63',
+    }
+
+    def _build_word_match(words, boundary_regex_fragment=None, prefix='', suffix=''):
+        if boundary_regex_fragment is None:
+            return r'\b(' + prefix + r'|'.join(re.escape(x) for x in words) + \
+                suffix + r')\b'
+        else:
+            return r'(? 0:
+                    next_open_pos = text.find(opening_chars, search_pos + n_chars)
+                    next_close_pos = text.find(closing_chars, search_pos + n_chars)
+
+                    if next_close_pos == -1:
+                        next_close_pos = len(text)
+                        nesting_level = 0
+                    elif next_open_pos != -1 and next_open_pos < next_close_pos:
+                        nesting_level += 1
+                        search_pos = next_open_pos
+                    else:  # next_close_pos < next_open_pos
+                        nesting_level -= 1
+                        search_pos = next_close_pos
+
+                end_pos = next_close_pos
+
+            if end_pos < 0:     # if we didn't find a closer, just highlight the
+                                # rest of the text in this class
+                end_pos = len(text)
+
+            if adverbs is not None and re.search(r':to\b', adverbs):
+                heredoc_terminator = text[match.start('delimiter') + n_chars:end_pos]
+                end_heredoc = re.search(r'^\s*' + re.escape(heredoc_terminator) +
+                                        r'\s*$', text[end_pos:], re.MULTILINE)
+
+                if end_heredoc:
+                    end_pos += end_heredoc.end()
+                else:
+                    end_pos = len(text)
+
+            yield match.start(), token_class, text[match.start():end_pos + n_chars]
+            context.pos = end_pos + n_chars
+
+        return callback
+
+    def opening_brace_callback(lexer, match, context):
+        stack = context.stack
+
+        yield match.start(), Text, context.text[match.start():match.end()]
+        context.pos = match.end()
+
+        # if we encounter an opening brace and we're one level
+        # below a token state, it means we need to increment
+        # the nesting level for braces so we know later when
+        # we should return to the token rules.
+        if len(stack) > 2 and stack[-2] == 'token':
+            context.perl6_token_nesting_level += 1
+
+    def closing_brace_callback(lexer, match, context):
+        stack = context.stack
+
+        yield match.start(), Text, context.text[match.start():match.end()]
+        context.pos = match.end()
+
+        # if we encounter a free closing brace and we're one level
+        # below a token state, it means we need to check the nesting
+        # level to see if we need to return to the token state.
+        if len(stack) > 2 and stack[-2] == 'token':
+            context.perl6_token_nesting_level -= 1
+            if context.perl6_token_nesting_level == 0:
+                stack.pop()
+
+    def embedded_perl6_callback(lexer, match, context):
+        context.perl6_token_nesting_level = 1
+        yield match.start(), Text, context.text[match.start():match.end()]
+        context.pos = match.end()
+        context.stack.append('root')
+
+    # If you're modifying these rules, be careful if you need to process '{' or '}'
+    # characters. We have special logic for processing these characters (due to the fact
+    # that you can nest Perl 6 code in regex blocks), so if you need to process one of
+    # them, make sure you also process the corresponding one!
+    tokens = {
+        'common': [
+            (r'#[`|=](?P(?P[' + ''.join(PERL6_BRACKETS) + r'])(?P=first_char)*)',
+             brackets_callback(Comment.Multiline)),
+            (r'#[^\n]*$', Comment.Single),
+            (r'^(\s*)=begin\s+(\w+)\b.*?^\1=end\s+\2', Comment.Multiline),
+            (r'^(\s*)=for.*?\n\s*?\n', Comment.Multiline),
+            (r'^=.*?\n\s*?\n', Comment.Multiline),
+            (r'(regex|token|rule)(\s*' + PERL6_IDENTIFIER_RANGE + '+:sym)',
+             bygroups(Keyword, Name), 'token-sym-brackets'),
+            (r'(regex|token|rule)(?!' + PERL6_IDENTIFIER_RANGE + r')(\s*' + PERL6_IDENTIFIER_RANGE + '+)?',
+             bygroups(Keyword, Name), 'pre-token'),
+            # deal with a special case in the Perl 6 grammar (role q { ... })
+            (r'(role)(\s+)(q)(\s*)', bygroups(Keyword, Whitespace, Name, Whitespace)),
+            (_build_word_match(PERL6_KEYWORDS, PERL6_IDENTIFIER_RANGE), Keyword),
+            (_build_word_match(PERL6_BUILTIN_CLASSES, PERL6_IDENTIFIER_RANGE, suffix='(?::[UD])?'),
+             Name.Builtin),
+            (_build_word_match(PERL6_BUILTINS, PERL6_IDENTIFIER_RANGE), Name.Builtin),
+            # copied from PerlLexer
+            (r'[$@%&][.^:?=!~]?' + PERL6_IDENTIFIER_RANGE + '+(?:<<.*?>>|<.*?>|«.*?»)*',
+             Name.Variable),
+            (r'\$[!/](?:<<.*?>>|<.*?>|«.*?»)*', Name.Variable.Global),
+            (r'::\?\w+', Name.Variable.Global),
+            (r'[$@%&]\*' + PERL6_IDENTIFIER_RANGE + '+(?:<<.*?>>|<.*?>|«.*?»)*',
+             Name.Variable.Global),
+            (r'\$(?:<.*?>)+', Name.Variable),
+            (r'(?:q|qq|Q)[a-zA-Z]?\s*(?P:[\w\s:]+)?\s*(?P(?P[^0-9a-zA-Z:\s])'
+             r'(?P=first_char)*)', brackets_callback(String)),
+            # copied from PerlLexer
+            (r'0_?[0-7]+(_[0-7]+)*', Number.Oct),
+            (r'0x[0-9A-Fa-f]+(_[0-9A-Fa-f]+)*', Number.Hex),
+            (r'0b[01]+(_[01]+)*', Number.Bin),
+            (r'(?i)(\d*(_\d*)*\.\d+(_\d*)*|\d+(_\d*)*\.\d+(_\d*)*)(e[+-]?\d+)?',
+             Number.Float),
+            (r'(?i)\d+(_\d*)*e[+-]?\d+(_\d*)*', Number.Float),
+            (r'\d+(_\d+)*', Number.Integer),
+            (r'(?<=~~)\s*/(?:\\\\|\\/|.)*?/', String.Regex),
+            (r'(?<=[=(,])\s*/(?:\\\\|\\/|.)*?/', String.Regex),
+            (r'm\w+(?=\()', Name),
+            (r'(?:m|ms|rx)\s*(?P:[\w\s:]+)?\s*(?P(?P[^\w:\s])'
+             r'(?P=first_char)*)', brackets_callback(String.Regex)),
+            (r'(?:s|ss|tr)\s*(?::[\w\s:]+)?\s*/(?:\\\\|\\/|.)*?/(?:\\\\|\\/|.)*?/',
+             String.Regex),
+            (r'<[^\s=].*?\S>', String),
+            (_build_word_match(PERL6_OPERATORS), Operator),
+            (r'\w' + PERL6_IDENTIFIER_RANGE + '*', Name),
+            (r"'(\\\\|\\[^\\]|[^'\\])*'", String),
+            (r'"(\\\\|\\[^\\]|[^"\\])*"', String),
+        ],
+        'root': [
+            include('common'),
+            (r'\{', opening_brace_callback),
+            (r'\}', closing_brace_callback),
+            (r'.+?', Text),
+        ],
+        'pre-token': [
+            include('common'),
+            (r'\{', Text, ('#pop', 'token')),
+            (r'.+?', Text),
+        ],
+        'token-sym-brackets': [
+            (r'(?P(?P[' + ''.join(PERL6_BRACKETS) + '])(?P=first_char)*)',
+             brackets_callback(Name), ('#pop', 'pre-token')),
+            default(('#pop', 'pre-token')),
+        ],
+        'token': [
+            (r'\}', Text, '#pop'),
+            (r'(?<=:)(?:my|our|state|constant|temp|let).*?;', using(this)),
+            # make sure that quotes in character classes aren't treated as strings
+            (r'<(?:[-!?+.]\s*)?\[.*?\]>', String.Regex),
+            # make sure that '#' characters in quotes aren't treated as comments
+            (r"(?my|our)\s+)?(?:module|class|role|enum|grammar)', line)
+            if class_decl:
+                if saw_perl_decl or class_decl.group('scope') is not None:
+                    return True
+                rating = 0.05
+                continue
+            break
+
+        if ':=' in text:
+            # Same logic as above for PerlLexer
+            rating /= 2
+
+        return rating
+
+    def __init__(self, **options):
+        super().__init__(**options)
+        self.encoding = options.get('encoding', 'utf-8')
diff --git a/pygments/lexers/phix.py b/pygments/lexers/phix.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0b03775e27efedbcd60b6f8ef0df0c93bb1eebc
--- /dev/null
+++ b/pygments/lexers/phix.py
@@ -0,0 +1,363 @@
+"""
+    pygments.lexers.phix
+    ~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for Phix.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import re
+
+from pygments.lexer import RegexLexer, words
+from pygments.token import Text, Comment, Operator, Keyword, Name, String, \
+    Whitespace
+
+__all__ = ['PhixLexer']
+
+
+class PhixLexer(RegexLexer):
+    """
+    Pygments Lexer for Phix files (.exw).
+    See http://phix.x10.mx
+    """
+
+    name = 'Phix'
+    url = 'http://phix.x10.mx'
+    aliases = ['phix']
+    filenames = ['*.exw']
+    mimetypes = ['text/x-phix']
+    version_added = '2.14'
+
+    flags = re.MULTILINE    # nb: **NOT** re.DOTALL! (totally spanners comment handling)
+
+    preproc = (
+        'ifdef', 'elsifdef', 'elsedef'
+    )
+    # Note these lists are auto-generated by pwa/p2js.exw, when pwa\src\p2js_keywords.e (etc)
+    #     change, though of course subsequent copy/commit/pull requests are all manual steps.
+    types = (
+        'string', 'nullable_string', 'atom_string', 'atom', 'bool', 'boolean',
+        'cdCanvan', 'cdCanvas', 'complex', 'CURLcode', 'dictionary', 'int',
+        'integer', 'Ihandle', 'Ihandles', 'Ihandln', 'mpfr', 'mpq', 'mpz',
+        'mpz_or_string', 'number', 'rid_string', 'seq', 'sequence', 'timedate',
+        'object'
+    )
+    keywords = (
+        'abstract', 'class', 'continue', 'export', 'extends', 'nullable',
+        'private', 'public', 'static', 'struct', 'trace',
+        'and', 'break', 'by', 'case', 'catch', 'const', 'constant', 'debug',
+        'default', 'do', 'else', 'elsif', 'end', 'enum', 'exit', 'fallthru',
+        'fallthrough', 'for', 'forward', 'function', 'global', 'if', 'in',
+        'include', 'js', 'javascript', 'javascript_semantics', 'let', 'not',
+        'or', 'procedure', 'profile', 'profile_time', 'return', 'safe_mode',
+        'switch', 'then', 'to', 'try', 'type', 'type_check', 'until', 'warning',
+        'while', 'with', 'without', 'xor'
+    )
+    routines = (
+        'abort', 'abs', 'adjust_timedate', 'and_bits', 'and_bitsu', 'apply',
+        'append', 'arccos', 'arcsin', 'arctan', 'assert', 'atan2',
+        'atom_to_float32', 'atom_to_float64', 'bankers_rounding', 'beep',
+        'begins', 'binary_search', 'bits_to_int', 'bk_color', 'bytes_to_int',
+        'call_func', 'call_proc', 'cdCanvasActivate', 'cdCanvasArc',
+        'cdCanvasBegin', 'cdCanvasBox', 'cdCanvasChord', 'cdCanvasCircle',
+        'cdCanvasClear', 'cdCanvasEnd', 'cdCanvasFlush', 'cdCanvasFont',
+        'cdCanvasGetImageRGB', 'cdCanvasGetSize', 'cdCanvasGetTextAlignment',
+        'cdCanvasGetTextSize', 'cdCanvasLine', 'cdCanvasMark',
+        'cdCanvasMarkSize', 'cdCanvasMultiLineVectorText', 'cdCanvasPixel',
+        'cdCanvasRect', 'cdCanvasRoundedBox', 'cdCanvasRoundedRect',
+        'cdCanvasSector', 'cdCanvasSetAttribute', 'cdCanvasSetBackground',
+        'cdCanvasSetFillMode', 'cdCanvasSetForeground',
+        'cdCanvasSetInteriorStyle', 'cdCanvasSetLineStyle',
+        'cdCanvasSetLineWidth', 'cdCanvasSetTextAlignment', 'cdCanvasText',
+        'cdCanvasSetTextOrientation', 'cdCanvasGetTextOrientation',
+        'cdCanvasVectorText', 'cdCanvasVectorTextDirection',
+        'cdCanvasVectorTextSize', 'cdCanvasVertex', 'cdCreateCanvas',
+        'cdDecodeAlpha', 'cdDecodeColor', 'cdDecodeColorAlpha', 'cdEncodeAlpha',
+        'cdEncodeColor', 'cdEncodeColorAlpha', 'cdKillCanvas', 'cdVersion',
+        'cdVersionDate', 'ceil', 'change_timezone', 'choose', 'clear_screen',
+        'columnize', 'command_line', 'compare', 'complex_abs', 'complex_add',
+        'complex_arg', 'complex_conjugate', 'complex_cos', 'complex_cosh',
+        'complex_div', 'complex_exp', 'complex_imag', 'complex_inv',
+        'complex_log', 'complex_mul', 'complex_neg', 'complex_new',
+        'complex_norm', 'complex_power', 'complex_rho', 'complex_real',
+        'complex_round', 'complex_sin', 'complex_sinh', 'complex_sprint',
+        'complex_sqrt', 'complex_sub', 'complex_theta', 'concat', 'cos',
+        'crash', 'custom_sort', 'date', 'day_of_week', 'day_of_year',
+        'days_in_month', 'decode_base64', 'decode_flags', 'deep_copy', 'deld',
+        'deserialize', 'destroy_dict', 'destroy_queue', 'destroy_stack',
+        'dict_name', 'dict_size', 'elapsed', 'elapsed_short', 'encode_base64',
+        'equal', 'even', 'exp', 'extract', 'factorial', 'factors',
+        'file_size_k', 'find', 'find_all', 'find_any', 'find_replace', 'filter',
+        'flatten', 'float32_to_atom', 'float64_to_atom', 'floor',
+        'format_timedate', 'free_console', 'from_polar', 'gcd', 'get_file_base',
+        'get_file_extension', 'get_file_name', 'get_file_name_and_path',
+        'get_file_path', 'get_file_path_and_name', 'get_maxprime', 'get_prime',
+        'get_primes', 'get_primes_le', 'get_proper_dir', 'get_proper_path',
+        'get_rand', 'get_routine_info', 'get_test_abort', 'get_test_logfile',
+        'get_test_pause', 'get_test_verbosity', 'get_tzid', 'getd', 'getdd',
+        'getd_all_keys', 'getd_by_index', 'getd_index', 'getd_partial_key',
+        'glAttachShader', 'glBindBuffer', 'glBindTexture', 'glBufferData',
+        'glCanvasSpecialText', 'glClear', 'glClearColor', 'glColor',
+        'glCompileShader', 'glCreateBuffer', 'glCreateProgram',
+        'glCreateShader', 'glCreateTexture', 'glDeleteProgram',
+        'glDeleteShader', 'glDrawArrays', 'glEnable',
+        'glEnableVertexAttribArray', 'glFloat32Array', 'glInt32Array',
+        'glFlush', 'glGetAttribLocation', 'glGetError', 'glGetProgramInfoLog',
+        'glGetProgramParameter', 'glGetShaderInfoLog', 'glGetShaderParameter',
+        'glGetUniformLocation', 'glLinkProgram', 'glLoadIdentity',
+        'glMatrixMode', 'glOrtho', 'glRotatef', 'glShadeModel',
+        'glShaderSource', 'glSimpleA7texcoords', 'glTexImage2Dc',
+        'glTexParameteri', 'glTranslate', 'glUniform1f', 'glUniform1i',
+        'glUniformMatrix4fv', 'glUseProgram', 'glVertex',
+        'glVertexAttribPointer', 'glViewport', 'head', 'hsv_to_rgb', 'iff',
+        'iif', 'include_file', 'incl0de_file', 'insert', 'instance',
+        'int_to_bits', 'int_to_bytes', 'is_dict', 'is_integer', 's_leap_year',
+        'is_prime', 'is_prime2', 'islower', 'isupper', 'Icallback',
+        'iup_isdouble', 'iup_isprint', 'iup_XkeyBase', 'IupAppend', 'IupAlarm',
+        'IupBackgroundBox', 'IupButton', 'IupCalendar', 'IupCanvas',
+        'IupClipboard', 'IupClose', 'IupCloseOnEscape', 'IupControlsOpen',
+        'IupDatePick', 'IupDestroy', 'IupDialog', 'IupDrawArc', 'IupDrawBegin',
+        'IupDrawEnd', 'IupDrawGetSize', 'IupDrawGetTextSize', 'IupDrawLine',
+        'IupDrawRectangle', 'IupDrawText', 'IupExpander', 'IupFill',
+        'IupFlatLabel', 'IupFlatList', 'IupFlatTree', 'IupFlush', 'IupFrame',
+        'IupGetAttribute', 'IupGetAttributeId', 'IupGetAttributePtr',
+        'IupGetBrother', 'IupGetChild', 'IupGetChildCount', 'IupGetClassName',
+        'IupGetDialog', 'IupGetDialogChild', 'IupGetDouble', 'IupGetFocus',
+        'IupGetGlobal', 'IupGetGlobalInt', 'IupGetGlobalIntInt', 'IupGetInt',
+        'IupGetInt2', 'IupGetIntId', 'IupGetIntInt', 'IupGetParent',
+        'IupGLCanvas', 'IupGLCanvasOpen', 'IupGLMakeCurrent', 'IupGraph',
+        'IupHbox', 'IupHide', 'IupImage', 'IupImageRGBA', 'IupItem',
+        'iupKeyCodeToName', 'IupLabel', 'IupLink', 'IupList', 'IupMap',
+        'IupMenu', 'IupMenuItem', 'IupMessage', 'IupMessageDlg', 'IupMultiBox',
+        'IupMultiLine', 'IupNextField', 'IupNormaliser', 'IupOpen',
+        'IupPlayInput', 'IupPopup', 'IupPreviousField', 'IupProgressBar',
+        'IupRadio', 'IupRecordInput', 'IupRedraw', 'IupRefresh',
+        'IupRefreshChildren', 'IupSeparator', 'IupSetAttribute',
+        'IupSetAttributes', 'IupSetAttributeHandle', 'IupSetAttributeId',
+        'IupSetAttributePtr', 'IupSetCallback', 'IupSetCallbacks',
+        'IupSetDouble', 'IupSetFocus', 'IupSetGlobal', 'IupSetGlobalInt',
+        'IupSetGlobalFunction', 'IupSetHandle', 'IupSetInt',
+        'IupSetStrAttribute', 'IupSetStrGlobal', 'IupShow', 'IupShowXY',
+        'IupSplit', 'IupStoreAttribute', 'IupSubmenu', 'IupTable',
+        'IupTableClearSelected', 'IupTableClick_cb', 'IupTableGetSelected',
+        'IupTableResize_cb', 'IupTableSetData', 'IupTabs', 'IupText',
+        'IupTimer', 'IupToggle', 'IupTreeAddNodes', 'IupTreeView', 'IupUpdate',
+        'IupValuator', 'IupVbox', 'join', 'join_by', 'join_path', 'k_perm',
+        'largest', 'lcm', 'length', 'log', 'log10', 'log2', 'lower',
+        'm4_crossProduct', 'm4_inverse', 'm4_lookAt', 'm4_multiply',
+        'm4_normalize', 'm4_perspective', 'm4_subtractVectors', 'm4_xRotate',
+        'm4_yRotate', 'machine_bits', 'machine_word', 'match', 'match_all',
+        'match_replace', 'max', 'maxsq', 'min', 'minsq', 'mod', 'mpfr_add',
+        'mpfr_ceil', 'mpfr_cmp', 'mpfr_cmp_si', 'mpfr_const_pi', 'mpfr_div',
+        'mpfr_div_si', 'mpfr_div_z', 'mpfr_floor', 'mpfr_free', 'mpfr_get_d',
+        'mpfr_get_default_precision', 'mpfr_get_default_rounding_mode',
+        'mpfr_get_fixed', 'mpfr_get_precision', 'mpfr_get_si', 'mpfr_init',
+        'mpfr_inits', 'mpfr_init_set', 'mpfr_init_set_q', 'mpfr_init_set_z',
+        'mpfr_mul', 'mpfr_mul_si', 'mpfr_pow_si', 'mpfr_set', 'mpfr_set_d',
+        'mpfr_set_default_precision', 'mpfr_set_default_rounding_mode',
+        'mpfr_set_precision', 'mpfr_set_q', 'mpfr_set_si', 'mpfr_set_str',
+        'mpfr_set_z', 'mpfr_si_div', 'mpfr_si_sub', 'mpfr_sqrt', 'mpfr_sub',
+        'mpfr_sub_si', 'mpq_abs', 'mpq_add', 'mpq_add_si', 'mpq_canonicalize',
+        'mpq_cmp', 'mpq_cmp_si', 'mpq_div', 'mpq_div_2exp', 'mpq_free',
+        'mpq_get_den', 'mpq_get_num', 'mpq_get_str', 'mpq_init', 'mpq_init_set',
+        'mpq_init_set_si', 'mpq_init_set_str', 'mpq_init_set_z', 'mpq_inits',
+        'mpq_inv', 'mpq_mul', 'mpq_neg', 'mpq_set', 'mpq_set_si', 'mpq_set_str',
+        'mpq_set_z', 'mpq_sub', 'mpz_abs', 'mpz_add', 'mpz_addmul',
+        'mpz_addmul_ui', 'mpz_addmul_si', 'mpz_add_si', 'mpz_add_ui', 'mpz_and',
+        'mpz_bin_uiui', 'mpz_cdiv_q', 'mpz_cmp', 'mpz_cmp_si', 'mpz_divexact',
+        'mpz_divexact_ui', 'mpz_divisible_p', 'mpz_divisible_ui_p', 'mpz_even',
+        'mpz_fac_ui', 'mpz_factorstring', 'mpz_fdiv_q', 'mpz_fdiv_q_2exp',
+        'mpz_fdiv_q_ui', 'mpz_fdiv_qr', 'mpz_fdiv_r', 'mpz_fdiv_ui',
+        'mpz_fib_ui', 'mpz_fib2_ui', 'mpz_fits_atom', 'mpz_fits_integer',
+        'mpz_free', 'mpz_gcd', 'mpz_gcd_ui', 'mpz_get_atom', 'mpz_get_integer',
+        'mpz_get_short_str', 'mpz_get_str', 'mpz_init', 'mpz_init_set',
+        'mpz_inits', 'mpz_invert', 'mpz_lcm', 'mpz_lcm_ui', 'mpz_max',
+        'mpz_min', 'mpz_mod', 'mpz_mod_ui', 'mpz_mul', 'mpz_mul_2exp',
+        'mpz_mul_d', 'mpz_mul_si', 'mpz_neg', 'mpz_nthroot', 'mpz_odd',
+        'mpz_pollard_rho', 'mpz_pow_ui', 'mpz_powm', 'mpz_powm_ui', 'mpz_prime',
+        'mpz_prime_factors', 'mpz_prime_mr', 'mpz_rand', 'mpz_rand_ui',
+        'mpz_re_compose', 'mpz_remove', 'mpz_scan0', 'mpz_scan1', 'mpz_set',
+        'mpz_set_d', 'mpz_set_si', 'mpz_set_str', 'mpz_set_v', 'mpz_sign',
+        'mpz_sizeinbase', 'mpz_sqrt', 'mpz_sub', 'mpz_sub_si', 'mpz_sub_ui',
+        'mpz_si_sub', 'mpz_tdiv_q_2exp', 'mpz_tdiv_r_2exp', 'mpz_tstbit',
+        'mpz_ui_pow_ui', 'mpz_xor', 'named_dict', 'new_dict', 'new_queue',
+        'new_stack', 'not_bits', 'not_bitsu', 'odd', 'or_all', 'or_allu',
+        'or_bits', 'or_bitsu', 'ord', 'ordinal', 'ordinant',
+        'override_timezone', 'pad', 'pad_head', 'pad_tail', 'parse_date_string',
+        'papply', 'peep', 'peepn', 'peep_dict', 'permute', 'permutes',
+        'platform', 'pop', 'popn', 'pop_dict', 'power', 'pp', 'ppEx', 'ppExf',
+        'ppf', 'ppOpt', 'pq_add', 'pq_destroy', 'pq_empty', 'pq_new', 'pq_peek',
+        'pq_pop', 'pq_pop_data', 'pq_size', 'prepend', 'prime_factors',
+        'printf', 'product', 'proper', 'push', 'pushn', 'putd', 'puts',
+        'queue_empty', 'queue_size', 'rand', 'rand_range', 'reinstate',
+        'remainder', 'remove', 'remove_all', 'repeat', 'repeatch', 'replace',
+        'requires', 'reverse', 'rfind', 'rgb', 'rmatch', 'rmdr', 'rnd', 'round',
+        'routine_id', 'scanf', 'serialize', 'series', 'set_rand',
+        'set_test_abort', 'set_test_logfile', 'set_test_module',
+        'set_test_pause', 'set_test_verbosity', 'set_timedate_formats',
+        'set_timezone', 'setd', 'setd_default', 'shorten', 'sha256',
+        'shift_bits', 'shuffle', 'sign', 'sin', 'smallest', 'sort',
+        'sort_columns', 'speak', 'splice', 'split', 'split_any', 'split_by',
+        'sprint', 'sprintf', 'sq_abs', 'sq_add', 'sq_and', 'sq_and_bits',
+        'sq_arccos', 'sq_arcsin', 'sq_arctan', 'sq_atom', 'sq_ceil', 'sq_cmp',
+        'sq_cos', 'sq_div', 'sq_even', 'sq_eq', 'sq_floor', 'sq_floor_div',
+        'sq_ge', 'sq_gt', 'sq_int', 'sq_le', 'sq_log', 'sq_log10', 'sq_log2',
+        'sq_lt', 'sq_max', 'sq_min', 'sq_mod', 'sq_mul', 'sq_ne', 'sq_not',
+        'sq_not_bits', 'sq_odd', 'sq_or', 'sq_or_bits', 'sq_power', 'sq_rand',
+        'sq_remainder', 'sq_rmdr', 'sq_rnd', 'sq_round', 'sq_seq', 'sq_sign',
+        'sq_sin', 'sq_sqrt', 'sq_str', 'sq_sub', 'sq_tan', 'sq_trunc',
+        'sq_uminus', 'sq_xor', 'sq_xor_bits', 'sqrt', 'square_free',
+        'stack_empty', 'stack_size', 'substitute', 'substitute_all', 'sum',
+        'tail', 'tan', 'test_equal', 'test_fail', 'test_false',
+        'test_not_equal', 'test_pass', 'test_summary', 'test_true',
+        'text_color', 'throw', 'time', 'timedate_diff', 'timedelta',
+        'to_integer', 'to_number', 'to_rgb', 'to_string', 'traverse_dict',
+        'traverse_dict_partial_key', 'trim', 'trim_head', 'trim_tail', 'trunc',
+        'tagset', 'tagstart', 'typeof', 'unique', 'unix_dict', 'upper',
+        'utf8_to_utf32', 'utf32_to_utf8', 'version', 'vlookup', 'vslice',
+        'wglGetProcAddress', 'wildcard_file', 'wildcard_match', 'with_rho',
+        'with_theta', 'xml_new_doc', 'xml_new_element', 'xml_set_attribute',
+        'xml_sprint', 'xor_bits', 'xor_bitsu',
+        'accept', 'allocate', 'allocate_string', 'allow_break', 'ARM',
+        'atom_to_float80', 'c_func', 'c_proc', 'call_back', 'chdir',
+        'check_break', 'clearDib', 'close', 'closesocket', 'console',
+        'copy_file', 'create', 'create_directory', 'create_thread',
+        'curl_easy_cleanup', 'curl_easy_get_file', 'curl_easy_init',
+        'curl_easy_perform', 'curl_easy_perform_ex', 'curl_easy_setopt',
+        'curl_easy_strerror', 'curl_global_cleanup', 'curl_global_init',
+        'curl_slist_append', 'curl_slist_free_all', 'current_dir', 'cursor',
+        'define_c_func', 'define_c_proc', 'delete', 'delete_cs', 'delete_file',
+        'dir', 'DLL', 'drawDib', 'drawShadedPolygonToDib', 'ELF32', 'ELF64',
+        'enter_cs', 'eval', 'exit_thread', 'free', 'file_exists', 'final',
+        'float80_to_atom', 'format', 'get_bytes', 'get_file_date',
+        'get_file_size', 'get_file_type', 'get_interpreter', 'get_key',
+        'get_socket_error', 'get_text', 'get_thread_exitcode', 'get_thread_id',
+        'getc', 'getenv', 'gets', 'getsockaddr', 'glBegin', 'glCallList',
+        'glFrustum', 'glGenLists', 'glGetString', 'glLight', 'glMaterial',
+        'glNewList', 'glNormal', 'glPopMatrix', 'glPushMatrix', 'glRotate',
+        'glEnd', 'glEndList', 'glTexImage2D', 'goto', 'GUI', 'icons', 'ilASM',
+        'include_files', 'include_paths', 'init_cs', 'ip_to_string',
+        'IupConfig', 'IupConfigDialogClosed', 'IupConfigDialogShow',
+        'IupConfigGetVariableInt', 'IupConfigLoad', 'IupConfigSave',
+        'IupConfigSetVariableInt', 'IupExitLoop', 'IupFileDlg', 'IupFileList',
+        'IupGLSwapBuffers', 'IupHelp', 'IupLoopStep', 'IupMainLoop',
+        'IupNormalizer', 'IupPlot', 'IupPlotAdd', 'IupPlotBegin', 'IupPlotEnd',
+        'IupPlotInsert', 'IupSaveImage', 'IupTreeGetUserId', 'IupUser',
+        'IupVersion', 'IupVersionDate', 'IupVersionNumber', 'IupVersionShow',
+        'killDib', 'leave_cs', 'listen', 'manifest', 'mem_copy', 'mem_set',
+        'mpfr_gamma', 'mpfr_printf', 'mpfr_sprintf', 'mpz_export', 'mpz_import',
+        'namespace', 'new', 'newDib', 'open', 'open_dll', 'PE32', 'PE64',
+        'peek', 'peek_string', 'peek1s', 'peek1u', 'peek2s', 'peek2u', 'peek4s',
+        'peek4u', 'peek8s', 'peek8u', 'peekNS', 'peekns', 'peeknu', 'poke',
+        'poke2', 'poke4', 'poke8', 'pokeN', 'poke_string', 'poke_wstring',
+        'position', 'progress', 'prompt_number', 'prompt_string', 'read_file',
+        'read_lines', 'recv', 'resume_thread', 'seek', 'select', 'send',
+        'setHandler', 'shutdown', 'sleep', 'SO', 'sockaddr_in', 'socket',
+        'split_path', 'suspend_thread', 'system', 'system_exec', 'system_open',
+        'system_wait', 'task_clock_start', 'task_clock_stop', 'task_create',
+        'task_delay', 'task_list', 'task_schedule', 'task_self', 'task_status',
+        'task_suspend', 'task_yield', 'thread_safe_string', 'try_cs',
+        'utf8_to_utf16', 'utf16_to_utf8', 'utf16_to_utf32', 'utf32_to_utf16',
+        'video_config', 'WSACleanup', 'wait_thread', 'walk_dir', 'where',
+        'write_lines', 'wait_key'
+    )
+    constants = (
+        'ANY_QUEUE', 'ASCENDING', 'BLACK', 'BLOCK_CURSOR', 'BLUE',
+        'BRIGHT_CYAN', 'BRIGHT_BLUE', 'BRIGHT_GREEN', 'BRIGHT_MAGENTA',
+        'BRIGHT_RED', 'BRIGHT_WHITE', 'BROWN', 'C_DWORD', 'C_INT', 'C_POINTER',
+        'C_USHORT', 'C_WORD', 'CD_AMBER', 'CD_BLACK', 'CD_BLUE', 'CD_BOLD',
+        'CD_BOLD_ITALIC', 'CD_BOX', 'CD_CENTER', 'CD_CIRCLE', 'CD_CLOSED_LINES',
+        'CD_CONTINUOUS', 'CD_CUSTOM', 'CD_CYAN', 'CD_DARK_BLUE', 'CD_DARK_CYAN',
+        'CD_DARK_GRAY', 'CD_DARK_GREY', 'CD_DARK_GREEN', 'CD_DARK_MAGENTA',
+        'CD_DARK_RED', 'CD_DARK_YELLOW', 'CD_DASH_DOT', 'CD_DASH_DOT_DOT',
+        'CD_DASHED', 'CD_DBUFFER', 'CD_DEG2RAD', 'CD_DIAMOND', 'CD_DOTTED',
+        'CD_EAST', 'CD_EVENODD', 'CD_FILL', 'CD_GL', 'CD_GRAY', 'CD_GREY',
+        'CD_GREEN', 'CD_HATCH', 'CD_HOLLOW', 'CD_HOLLOW_BOX',
+        'CD_HOLLOW_CIRCLE', 'CD_HOLLOW_DIAMOND', 'CD_INDIGO', 'CD_ITALIC',
+        'CD_IUP', 'CD_IUPDBUFFER', 'CD_LIGHT_BLUE', 'CD_LIGHT_GRAY',
+        'CD_LIGHT_GREY', 'CD_LIGHT_GREEN', 'CD_LIGHT_PARCHMENT', 'CD_MAGENTA',
+        'CD_NAVY', 'CD_NORTH', 'CD_NORTH_EAST', 'CD_NORTH_WEST', 'CD_OLIVE',
+        'CD_OPEN_LINES', 'CD_ORANGE', 'CD_PARCHMENT', 'CD_PATTERN',
+        'CD_PRINTER', 'CD_PURPLE', 'CD_PLAIN', 'CD_PLUS', 'CD_QUERY',
+        'CD_RAD2DEG', 'CD_RED', 'CD_SILVER', 'CD_SOLID', 'CD_SOUTH_EAST',
+        'CD_SOUTH_WEST', 'CD_STAR', 'CD_STIPPLE', 'CD_STRIKEOUT',
+        'CD_UNDERLINE', 'CD_WEST', 'CD_WHITE', 'CD_WINDING', 'CD_VIOLET',
+        'CD_X', 'CD_YELLOW', 'CURLE_OK', 'CURLOPT_MAIL_FROM',
+        'CURLOPT_MAIL_RCPT', 'CURLOPT_PASSWORD', 'CURLOPT_READDATA',
+        'CURLOPT_READFUNCTION', 'CURLOPT_SSL_VERIFYPEER',
+        'CURLOPT_SSL_VERIFYHOST', 'CURLOPT_UPLOAD', 'CURLOPT_URL',
+        'CURLOPT_USE_SSL', 'CURLOPT_USERNAME', 'CURLOPT_VERBOSE',
+        'CURLOPT_WRITEFUNCTION', 'CURLUSESSL_ALL', 'CYAN', 'D_NAME',
+        'D_ATTRIBUTES', 'D_SIZE', 'D_YEAR', 'D_MONTH', 'D_DAY', 'D_HOUR',
+        'D_MINUTE', 'D_SECOND', 'D_CREATION', 'D_LASTACCESS', 'D_MODIFICATION',
+        'DT_YEAR', 'DT_MONTH', 'DT_DAY', 'DT_HOUR', 'DT_MINUTE', 'DT_SECOND',
+        'DT_DOW', 'DT_MSEC', 'DT_DOY', 'DT_GMT', 'EULER', 'E_CODE', 'E_ADDR',
+        'E_LINE', 'E_RTN', 'E_NAME', 'E_FILE', 'E_PATH', 'E_USER', 'false',
+        'False', 'FALSE', 'FIFO_QUEUE', 'FILETYPE_DIRECTORY', 'FILETYPE_FILE',
+        'GET_EOF', 'GET_FAIL', 'GET_IGNORE', 'GET_SUCCESS',
+        'GL_AMBIENT_AND_DIFFUSE', 'GL_ARRAY_BUFFER', 'GL_CLAMP',
+        'GL_CLAMP_TO_BORDER', 'GL_CLAMP_TO_EDGE', 'GL_COLOR_BUFFER_BIT',
+        'GL_COMPILE', 'GL_COMPILE_STATUS', 'GL_CULL_FACE',
+        'GL_DEPTH_BUFFER_BIT', 'GL_DEPTH_TEST', 'GL_EXTENSIONS', 'GL_FLAT',
+        'GL_FLOAT', 'GL_FRAGMENT_SHADER', 'GL_FRONT', 'GL_LIGHT0',
+        'GL_LIGHTING', 'GL_LINEAR', 'GL_LINK_STATUS', 'GL_MODELVIEW',
+        'GL_NEAREST', 'GL_NO_ERROR', 'GL_NORMALIZE', 'GL_POSITION',
+        'GL_PROJECTION', 'GL_QUAD_STRIP', 'GL_QUADS', 'GL_RENDERER',
+        'GL_REPEAT', 'GL_RGB', 'GL_RGBA', 'GL_SMOOTH', 'GL_STATIC_DRAW',
+        'GL_TEXTURE_2D', 'GL_TEXTURE_MAG_FILTER', 'GL_TEXTURE_MIN_FILTER',
+        'GL_TEXTURE_WRAP_S', 'GL_TEXTURE_WRAP_T', 'GL_TRIANGLES',
+        'GL_UNSIGNED_BYTE', 'GL_VENDOR', 'GL_VERSION', 'GL_VERTEX_SHADER',
+        'GRAY', 'GREEN', 'GT_LF_STRIPPED', 'GT_WHOLE_FILE', 'INVLN10',
+        'IUP_CLOSE', 'IUP_CONTINUE', 'IUP_DEFAULT', 'IUP_BLACK', 'IUP_BLUE',
+        'IUP_BUTTON1', 'IUP_BUTTON3', 'IUP_CENTER', 'IUP_CYAN', 'IUP_DARK_BLUE',
+        'IUP_DARK_CYAN', 'IUP_DARK_GRAY', 'IUP_DARK_GREY', 'IUP_DARK_GREEN',
+        'IUP_DARK_MAGENTA', 'IUP_DARK_RED', 'IUP_GRAY', 'IUP_GREY', 'IUP_GREEN',
+        'IUP_IGNORE', 'IUP_INDIGO', 'IUP_MAGENTA', 'IUP_MASK_INT',
+        'IUP_MASK_UINT', 'IUP_MOUSEPOS', 'IUP_NAVY', 'IUP_OLIVE', 'IUP_RECTEXT',
+        'IUP_RED', 'IUP_LIGHT_BLUE', 'IUP_LIGHT_GRAY', 'IUP_LIGHT_GREY',
+        'IUP_LIGHT_GREEN', 'IUP_ORANGE', 'IUP_PARCHMENT', 'IUP_PURPLE',
+        'IUP_SILVER', 'IUP_TEAL', 'IUP_VIOLET', 'IUP_WHITE', 'IUP_YELLOW',
+        'K_BS', 'K_cA', 'K_cC', 'K_cD', 'K_cF5', 'K_cK', 'K_cM', 'K_cN', 'K_cO',
+        'K_cP', 'K_cR', 'K_cS', 'K_cT', 'K_cW', 'K_CR', 'K_DEL', 'K_DOWN',
+        'K_END', 'K_ESC', 'K_F1', 'K_F2', 'K_F3', 'K_F4', 'K_F5', 'K_F6',
+        'K_F7', 'K_F8', 'K_F9', 'K_F10', 'K_F11', 'K_F12', 'K_HOME', 'K_INS',
+        'K_LEFT', 'K_MIDDLE', 'K_PGDN', 'K_PGUP', 'K_RIGHT', 'K_SP', 'K_TAB',
+        'K_UP', 'K_h', 'K_i', 'K_j', 'K_p', 'K_r', 'K_s', 'JS', 'LIFO_QUEUE',
+        'LINUX', 'MAX_HEAP', 'MAGENTA', 'MIN_HEAP', 'Nan', 'NO_CURSOR', 'null',
+        'NULL', 'PI', 'pp_Ascii', 'pp_Brkt', 'pp_Date', 'pp_File', 'pp_FltFmt',
+        'pp_Indent', 'pp_IntCh', 'pp_IntFmt', 'pp_Maxlen', 'pp_Nest',
+        'pp_Pause', 'pp_Q22', 'pp_StrFmt', 'RED', 'SEEK_OK', 'SLASH',
+        'TEST_ABORT', 'TEST_CRASH', 'TEST_PAUSE', 'TEST_PAUSE_FAIL',
+        'TEST_QUIET', 'TEST_SHOW_ALL', 'TEST_SHOW_FAILED', 'TEST_SUMMARY',
+        'true', 'True', 'TRUE', 'VC_SCRNLINES', 'WHITE', 'WINDOWS', 'YELLOW'
+    )
+
+    tokens = {
+        'root': [
+            (r"\s+", Whitespace),
+            (r'/\*|--/\*|#\[', Comment.Multiline, 'comment'),
+            (r'(?://|--|#!).*$', Comment.Single),
+#Alt:
+#           (r'//.*$|--.*$|#!.*$', Comment.Single),
+            (r'"([^"\\]|\\.)*"', String.Other),
+            (r'\'[^\']*\'', String.Other),
+            (r'`[^`]*`', String.Other),
+
+            (words(types, prefix=r'\b', suffix=r'\b'), Name.Function),
+            (words(routines, prefix=r'\b', suffix=r'\b'), Name.Function),
+            (words(preproc, prefix=r'\b', suffix=r'\b'), Keyword.Declaration),
+            (words(keywords, prefix=r'\b', suffix=r'\b'), Keyword.Declaration),
+            (words(constants, prefix=r'\b', suffix=r'\b'), Name.Constant),
+            # Aside: Phix only supports/uses the ascii/non-unicode tilde
+            (r'!=|==|<<|>>|:=|[-~+/*%=<>&^|\.(){},?:\[\]$\\;#]', Operator),
+            (r'[\w-]+', Text)
+        ],
+        'comment': [
+            (r'[^*/#]+', Comment.Multiline),
+            (r'/\*|#\[', Comment.Multiline, '#push'),
+            (r'\*/|#\]', Comment.Multiline, '#pop'),
+            (r'[*/#]', Comment.Multiline)
+        ]
+    }
diff --git a/pygments/lexers/php.py b/pygments/lexers/php.py
new file mode 100644
index 0000000000000000000000000000000000000000..82d4aeb3d9d8eedca8cd33fa63443a5936dfc8e1
--- /dev/null
+++ b/pygments/lexers/php.py
@@ -0,0 +1,334 @@
+"""
+    pygments.lexers.php
+    ~~~~~~~~~~~~~~~~~~~
+
+    Lexers for PHP and related languages.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import re
+
+from pygments.lexer import Lexer, RegexLexer, include, bygroups, default, \
+    using, this, words, do_insertions, line_re
+from pygments.token import Text, Comment, Operator, Keyword, Name, String, \
+    Number, Punctuation, Other, Generic
+from pygments.util import get_bool_opt, get_list_opt, shebang_matches
+
+__all__ = ['ZephirLexer', 'PsyshConsoleLexer', 'PhpLexer']
+
+
+class ZephirLexer(RegexLexer):
+    """
+    For Zephir language source code.
+
+    Zephir is a compiled high level language aimed
+    to the creation of C-extensions for PHP.
+    """
+
+    name = 'Zephir'
+    url = 'http://zephir-lang.com/'
+    aliases = ['zephir']
+    filenames = ['*.zep']
+    version_added = '2.0'
+
+    zephir_keywords = ['fetch', 'echo', 'isset', 'empty']
+    zephir_type = ['bit', 'bits', 'string']
+
+    flags = re.DOTALL | re.MULTILINE
+
+    tokens = {
+        'commentsandwhitespace': [
+            (r'\s+', Text),
+            (r'//.*?\n', Comment.Single),
+            (r'/\*.*?\*/', Comment.Multiline)
+        ],
+        'slashstartsregex': [
+            include('commentsandwhitespace'),
+            (r'/(\\.|[^[/\\\n]|\[(\\.|[^\]\\\n])*])+/'
+             r'([gim]+\b|\B)', String.Regex, '#pop'),
+            (r'/', Operator, '#pop'),
+            default('#pop')
+        ],
+        'badregex': [
+            (r'\n', Text, '#pop')
+        ],
+        'root': [
+            (r'^(?=\s|/)', Text, 'slashstartsregex'),
+            include('commentsandwhitespace'),
+            (r'\+\+|--|~|&&|\?|:|\|\||\\(?=\n)|'
+             r'(<<|>>>?|==?|!=?|->|[-<>+*%&|^/])=?', Operator, 'slashstartsregex'),
+            (r'[{(\[;,]', Punctuation, 'slashstartsregex'),
+            (r'[})\].]', Punctuation),
+            (r'(for|in|while|do|break|return|continue|switch|case|default|if|else|loop|'
+             r'require|inline|throw|try|catch|finally|new|delete|typeof|instanceof|void|'
+             r'namespace|use|extends|this|fetch|isset|unset|echo|fetch|likely|unlikely|'
+             r'empty)\b', Keyword, 'slashstartsregex'),
+            (r'(var|let|with|function)\b', Keyword.Declaration, 'slashstartsregex'),
+            (r'(abstract|boolean|bool|char|class|const|double|enum|export|extends|final|'
+             r'native|goto|implements|import|int|string|interface|long|ulong|char|uchar|'
+             r'float|unsigned|private|protected|public|short|static|self|throws|reverse|'
+             r'transient|volatile|readonly)\b', Keyword.Reserved),
+            (r'(true|false|null|undefined)\b', Keyword.Constant),
+            (r'(Array|Boolean|Date|_REQUEST|_COOKIE|_SESSION|'
+             r'_GET|_POST|_SERVER|this|stdClass|range|count|iterator|'
+             r'window)\b', Name.Builtin),
+            (r'[$a-zA-Z_][\w\\]*', Name.Other),
+            (r'[0-9][0-9]*\.[0-9]+([eE][0-9]+)?[fd]?', Number.Float),
+            (r'0x[0-9a-fA-F]+', Number.Hex),
+            (r'[0-9]+', Number.Integer),
+            (r'"(\\\\|\\[^\\]|[^"\\])*"', String.Double),
+            (r"'(\\\\|\\[^\\]|[^'\\])*'", String.Single),
+        ]
+    }
+
+
+class PsyshConsoleLexer(Lexer):
+    """
+    For PsySH console output, such as:
+
+    .. sourcecode:: psysh
+
+        >>> $greeting = function($name): string {
+        ...     return "Hello, {$name}";
+        ... };
+        => Closure($name): string {#2371 …3}
+        >>> $greeting('World')
+        => "Hello, World"
+    """
+    name = 'PsySH console session for PHP'
+    url = 'https://psysh.org/'
+    aliases = ['psysh']
+    version_added = '2.7'
+
+    def __init__(self, **options):
+        options['startinline'] = True
+        Lexer.__init__(self, **options)
+
+    def get_tokens_unprocessed(self, text):
+        phplexer = PhpLexer(**self.options)
+        curcode = ''
+        insertions = []
+        for match in line_re.finditer(text):
+            line = match.group()
+            if line.startswith('>>> ') or line.startswith('... '):
+                insertions.append((len(curcode),
+                                   [(0, Generic.Prompt, line[:4])]))
+                curcode += line[4:]
+            elif line.rstrip() == '...':
+                insertions.append((len(curcode),
+                                   [(0, Generic.Prompt, '...')]))
+                curcode += line[3:]
+            else:
+                if curcode:
+                    yield from do_insertions(
+                        insertions, phplexer.get_tokens_unprocessed(curcode))
+                    curcode = ''
+                    insertions = []
+                yield match.start(), Generic.Output, line
+        if curcode:
+            yield from do_insertions(insertions,
+                                     phplexer.get_tokens_unprocessed(curcode))
+
+
+class PhpLexer(RegexLexer):
+    """
+    For PHP source code.
+    For PHP embedded in HTML, use the `HtmlPhpLexer`.
+
+    Additional options accepted:
+
+    `startinline`
+        If given and ``True`` the lexer starts highlighting with
+        php code (i.e.: no starting ``>> from pygments.lexers._php_builtins import MODULES
+            >>> MODULES.keys()
+            ['PHP Options/Info', 'Zip', 'dba', ...]
+
+        In fact the names of those modules match the module names from
+        the php documentation.
+    """
+
+    name = 'PHP'
+    url = 'https://www.php.net/'
+    aliases = ['php', 'php3', 'php4', 'php5']
+    filenames = ['*.php', '*.php[345]', '*.inc']
+    mimetypes = ['text/x-php']
+    version_added = ''
+
+    # Note that a backslash is included, PHP uses a backslash as a namespace
+    # separator.
+    _ident_inner = r'(?:[\\_a-z]|[^\x00-\x7f])(?:[\\\w]|[^\x00-\x7f])*'
+    # But not inside strings.
+    _ident_nons = r'(?:[_a-z]|[^\x00-\x7f])(?:\w|[^\x00-\x7f])*'
+
+    flags = re.IGNORECASE | re.DOTALL | re.MULTILINE
+    tokens = {
+        'root': [
+            (r'<\?(php)?', Comment.Preproc, 'php'),
+            (r'[^<]+', Other),
+            (r'<', Other)
+        ],
+        'php': [
+            (r'\?>', Comment.Preproc, '#pop'),
+            (r'(<<<)([\'"]?)(' + _ident_nons + r')(\2\n.*?\n\s*)(\3)(;?)(\n)',
+             bygroups(String, String, String.Delimiter, String, String.Delimiter,
+                      Punctuation, Text)),
+            (r'\s+', Text),
+            (r'#\[', Punctuation, 'attribute'),
+            (r'#.*?\n', Comment.Single),
+            (r'//.*?\n', Comment.Single),
+            # put the empty comment here, it is otherwise seen as
+            # the start of a docstring
+            (r'/\*\*/', Comment.Multiline),
+            (r'/\*\*.*?\*/', String.Doc),
+            (r'/\*.*?\*/', Comment.Multiline),
+            (r'(->|::)(\s*)(' + _ident_nons + ')',
+             bygroups(Operator, Text, Name.Attribute)),
+            (r'[~!%^&*+=|:.<>/@-]+', Operator),
+            (r'\?', Operator),  # don't add to the charclass above!
+            (r'[\[\]{}();,]+', Punctuation),
+            (r'(new)(\s+)(class)\b', bygroups(Keyword, Text, Keyword)),
+            (r'(class)(\s+)', bygroups(Keyword, Text), 'classname'),
+            (r'(function)(\s*)(?=\()', bygroups(Keyword, Text)),
+            (r'(function)(\s+)(&?)(\s*)',
+             bygroups(Keyword, Text, Operator, Text), 'functionname'),
+            (r'(const)(\s+)(' + _ident_inner + ')',
+             bygroups(Keyword, Text, Name.Constant)),
+            (r'(and|E_PARSE|old_function|E_ERROR|or|as|E_WARNING|parent|'
+             r'eval|PHP_OS|break|exit|case|extends|PHP_VERSION|cfunction|'
+             r'FALSE|print|for|require|continue|foreach|require_once|'
+             r'declare|return|default|static|do|switch|die|stdClass|'
+             r'echo|else|TRUE|elseif|var|empty|if|xor|enddeclare|include|'
+             r'virtual|endfor|include_once|while|endforeach|global|'
+             r'endif|list|endswitch|new|endwhile|not|'
+             r'array|E_ALL|NULL|final|php_user_filter|interface|'
+             r'implements|public|private|protected|abstract|clone|try|'
+             r'catch|throw|this|use|namespace|trait|yield|'
+             r'finally|match)\b', Keyword),
+            (r'(true|false|null)\b', Keyword.Constant),
+            include('magicconstants'),
+            (r'\$\{', Name.Variable, 'variablevariable'),
+            (r'\$+' + _ident_inner, Name.Variable),
+            (_ident_inner, Name.Other),
+            (r'(\d+\.\d*|\d*\.\d+)(e[+-]?[0-9]+)?', Number.Float),
+            (r'\d+e[+-]?[0-9]+', Number.Float),
+            (r'0[0-7]+', Number.Oct),
+            (r'0x[a-f0-9]+', Number.Hex),
+            (r'\d+', Number.Integer),
+            (r'0b[01]+', Number.Bin),
+            (r"'([^'\\]*(?:\\.[^'\\]*)*)'", String.Single),
+            (r'`([^`\\]*(?:\\.[^`\\]*)*)`', String.Backtick),
+            (r'"', String.Double, 'string'),
+        ],
+        'variablevariable': [
+            (r'\}', Name.Variable, '#pop'),
+            include('php')
+        ],
+        'magicfuncs': [
+            # source: http://php.net/manual/en/language.oop5.magic.php
+            (words((
+                '__construct', '__destruct', '__call', '__callStatic', '__get', '__set',
+                '__isset', '__unset', '__sleep', '__wakeup', '__toString', '__invoke',
+                '__set_state', '__clone', '__debugInfo',), suffix=r'\b'),
+             Name.Function.Magic),
+        ],
+        'magicconstants': [
+            # source: http://php.net/manual/en/language.constants.predefined.php
+            (words((
+                '__LINE__', '__FILE__', '__DIR__', '__FUNCTION__', '__CLASS__',
+                '__TRAIT__', '__METHOD__', '__NAMESPACE__',),
+                suffix=r'\b'),
+             Name.Constant),
+        ],
+        'classname': [
+            (_ident_inner, Name.Class, '#pop')
+        ],
+        'functionname': [
+            include('magicfuncs'),
+            (_ident_inner, Name.Function, '#pop'),
+            default('#pop')
+        ],
+        'string': [
+            (r'"', String.Double, '#pop'),
+            (r'[^{$"\\]+', String.Double),
+            (r'\\([nrt"$\\]|[0-7]{1,3}|x[0-9a-f]{1,2})', String.Escape),
+            (r'\$' + _ident_nons + r'(\[\S+?\]|->' + _ident_nons + ')?',
+             String.Interpol),
+            (r'(\{\$\{)(.*?)(\}\})',
+             bygroups(String.Interpol, using(this, _startinline=True),
+                      String.Interpol)),
+            (r'(\{)(\$.*?)(\})',
+             bygroups(String.Interpol, using(this, _startinline=True),
+                      String.Interpol)),
+            (r'(\$\{)(\S+)(\})',
+             bygroups(String.Interpol, Name.Variable, String.Interpol)),
+            (r'[${\\]', String.Double)
+        ],
+        'attribute': [
+            (r'\]', Punctuation, '#pop'),
+            (r'\(', Punctuation, 'attributeparams'),
+            (_ident_inner, Name.Decorator),
+            include('php')
+        ],
+        'attributeparams': [
+            (r'\)', Punctuation, '#pop'),
+            include('php')
+        ],
+    }
+
+    def __init__(self, **options):
+        self.funcnamehighlighting = get_bool_opt(
+            options, 'funcnamehighlighting', True)
+        self.disabledmodules = get_list_opt(
+            options, 'disabledmodules', ['unknown'])
+        self.startinline = get_bool_opt(options, 'startinline', False)
+
+        # private option argument for the lexer itself
+        if '_startinline' in options:
+            self.startinline = options.pop('_startinline')
+
+        # collect activated functions in a set
+        self._functions = set()
+        if self.funcnamehighlighting:
+            from pygments.lexers._php_builtins import MODULES
+            for key, value in MODULES.items():
+                if key not in self.disabledmodules:
+                    self._functions.update(value)
+        RegexLexer.__init__(self, **options)
+
+    def get_tokens_unprocessed(self, text):
+        stack = ['root']
+        if self.startinline:
+            stack.append('php')
+        for index, token, value in \
+                RegexLexer.get_tokens_unprocessed(self, text, stack):
+            if token is Name.Other:
+                if value in self._functions:
+                    yield index, Name.Builtin, value
+                    continue
+            yield index, token, value
+
+    def analyse_text(text):
+        if shebang_matches(text, r'php'):
+            return True
+        rv = 0.0
+        if re.search(r'<\?(?!xml)', text):
+            rv += 0.3
+        return rv
diff --git a/pygments/lexers/pointless.py b/pygments/lexers/pointless.py
new file mode 100644
index 0000000000000000000000000000000000000000..adedb757d891928fe415b11dfae0917f430a7a6d
--- /dev/null
+++ b/pygments/lexers/pointless.py
@@ -0,0 +1,70 @@
+"""
+    pygments.lexers.pointless
+    ~~~~~~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for Pointless.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, words
+from pygments.token import Comment, Error, Keyword, Name, Number, Operator, \
+    Punctuation, String, Text
+
+__all__ = ['PointlessLexer']
+
+
+class PointlessLexer(RegexLexer):
+    """
+    For Pointless source code.
+    """
+
+    name = 'Pointless'
+    url = 'https://ptls.dev'
+    aliases = ['pointless']
+    filenames = ['*.ptls']
+    version_added = '2.7'
+
+    ops = words([
+        "+", "-", "*", "/", "**", "%", "+=", "-=", "*=",
+        "/=", "**=", "%=", "|>", "=", "==", "!=", "<", ">",
+        "<=", ">=", "=>", "$", "++",
+    ])
+
+    keywords = words([
+        "if", "then", "else", "where", "with", "cond",
+        "case", "and", "or", "not", "in", "as", "for",
+        "requires", "throw", "try", "catch", "when",
+        "yield", "upval",
+    ], suffix=r'\b')
+
+    tokens = {
+        'root': [
+            (r'[ \n\r]+', Text),
+            (r'--.*$', Comment.Single),
+            (r'"""', String, 'multiString'),
+            (r'"', String, 'string'),
+            (r'[\[\](){}:;,.]', Punctuation),
+            (ops, Operator),
+            (keywords, Keyword),
+            (r'\d+|\d*\.\d+', Number),
+            (r'(true|false)\b', Name.Builtin),
+            (r'[A-Z][a-zA-Z0-9]*\b', String.Symbol),
+            (r'output\b', Name.Variable.Magic),
+            (r'(export|import)\b', Keyword.Namespace),
+            (r'[a-z][a-zA-Z0-9]*\b', Name.Variable)
+        ],
+        'multiString': [
+            (r'\\.', String.Escape),
+            (r'"""', String, '#pop'),
+            (r'"', String),
+            (r'[^\\"]+', String),
+        ],
+        'string': [
+            (r'\\.', String.Escape),
+            (r'"', String, '#pop'),
+            (r'\n', Error),
+            (r'[^\\"]+', String),
+        ],
+    }
diff --git a/pygments/lexers/pony.py b/pygments/lexers/pony.py
new file mode 100644
index 0000000000000000000000000000000000000000..055423a4694c3aa9ecb5f3ff6e64d17d6487be93
--- /dev/null
+++ b/pygments/lexers/pony.py
@@ -0,0 +1,93 @@
+"""
+    pygments.lexers.pony
+    ~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for Pony and related languages.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, bygroups, words
+from pygments.token import Text, Comment, Operator, Keyword, Name, String, \
+    Number, Punctuation
+
+__all__ = ['PonyLexer']
+
+
+class PonyLexer(RegexLexer):
+    """
+    For Pony source code.
+    """
+
+    name = 'Pony'
+    aliases = ['pony']
+    filenames = ['*.pony']
+    url = 'https://www.ponylang.io'
+    version_added = '2.4'
+
+    _caps = r'(iso|trn|ref|val|box|tag)'
+
+    tokens = {
+        'root': [
+            (r'\n', Text),
+            (r'[^\S\n]+', Text),
+            (r'//.*\n', Comment.Single),
+            (r'/\*', Comment.Multiline, 'nested_comment'),
+            (r'"""(?:.|\n)*?"""', String.Doc),
+            (r'"', String, 'string'),
+            (r'\'.*\'', String.Char),
+            (r'=>|[]{}:().~;,|&!^?[]', Punctuation),
+            (words((
+                'addressof', 'and', 'as', 'consume', 'digestof', 'is', 'isnt',
+                'not', 'or'),
+                suffix=r'\b'),
+             Operator.Word),
+            (r'!=|==|<<|>>|[-+/*%=<>]', Operator),
+            (words((
+                'box', 'break', 'compile_error', 'compile_intrinsic',
+                'continue', 'do', 'else', 'elseif', 'embed', 'end', 'error',
+                'for', 'if', 'ifdef', 'in', 'iso', 'lambda', 'let', 'match',
+                'object', 'recover', 'ref', 'repeat', 'return', 'tag', 'then',
+                'this', 'trn', 'try', 'until', 'use', 'var', 'val', 'where',
+                'while', 'with', '#any', '#read', '#send', '#share'),
+                suffix=r'\b'),
+             Keyword),
+            (r'(actor|class|struct|primitive|interface|trait|type)((?:\s)+)',
+             bygroups(Keyword, Text), 'typename'),
+            (r'(new|fun|be)((?:\s)+)', bygroups(Keyword, Text), 'methodname'),
+            (words((
+                'I8', 'U8', 'I16', 'U16', 'I32', 'U32', 'I64', 'U64', 'I128',
+                'U128', 'ILong', 'ULong', 'ISize', 'USize', 'F32', 'F64',
+                'Bool', 'Pointer', 'None', 'Any', 'Array', 'String',
+                'Iterator'),
+                suffix=r'\b'),
+             Name.Builtin.Type),
+            (r'_?[A-Z]\w*', Name.Type),
+            (r'(\d+\.\d*|\.\d+|\d+)[eE][+-]?\d+', Number.Float),
+            (r'0x[0-9a-fA-F]+', Number.Hex),
+            (r'\d+', Number.Integer),
+            (r'(true|false)\b', Name.Builtin),
+            (r'_\d*', Name),
+            (r'_?[a-z][\w\']*', Name)
+        ],
+        'typename': [
+            (_caps + r'?((?:\s)*)(_?[A-Z]\w*)',
+             bygroups(Keyword, Text, Name.Class), '#pop')
+        ],
+        'methodname': [
+            (_caps + r'?((?:\s)*)(_?[a-z]\w*)',
+             bygroups(Keyword, Text, Name.Function), '#pop')
+        ],
+        'nested_comment': [
+            (r'[^*/]+', Comment.Multiline),
+            (r'/\*', Comment.Multiline, '#push'),
+            (r'\*/', Comment.Multiline, '#pop'),
+            (r'[*/]', Comment.Multiline)
+        ],
+        'string': [
+            (r'"', String, '#pop'),
+            (r'\\"', String),
+            (r'[^\\"]+', String)
+        ]
+    }
diff --git a/pygments/lexers/praat.py b/pygments/lexers/praat.py
new file mode 100644
index 0000000000000000000000000000000000000000..054f5b61e56a53d2bf04d2bba2fc15c1fa85566d
--- /dev/null
+++ b/pygments/lexers/praat.py
@@ -0,0 +1,303 @@
+"""
+    pygments.lexers.praat
+    ~~~~~~~~~~~~~~~~~~~~~
+
+    Lexer for Praat
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, words, bygroups, include
+from pygments.token import Name, Text, Comment, Keyword, String, Punctuation, \
+    Number, Operator, Whitespace
+
+__all__ = ['PraatLexer']
+
+
+class PraatLexer(RegexLexer):
+    """
+    For Praat scripts.
+    """
+
+    name = 'Praat'
+    url = 'http://www.praat.org'
+    aliases = ['praat']
+    filenames = ['*.praat', '*.proc', '*.psc']
+    version_added = '2.1'
+
+    keywords = (
+        'if', 'then', 'else', 'elsif', 'elif', 'endif', 'fi', 'for', 'from', 'to',
+        'endfor', 'endproc', 'while', 'endwhile', 'repeat', 'until', 'select', 'plus',
+        'minus', 'demo', 'assert', 'stopwatch', 'nocheck', 'nowarn', 'noprogress',
+        'editor', 'endeditor', 'clearinfo',
+    )
+
+    functions_string = (
+        'backslashTrigraphsToUnicode', 'chooseDirectory', 'chooseReadFile',
+        'chooseWriteFile', 'date', 'demoKey', 'do', 'environment', 'extractLine',
+        'extractWord', 'fixed', 'info', 'left', 'mid', 'percent', 'readFile', 'replace',
+        'replace_regex', 'right', 'selected', 'string', 'unicodeToBackslashTrigraphs',
+    )
+
+    functions_numeric = (
+        'abs', 'appendFile', 'appendFileLine', 'appendInfo', 'appendInfoLine', 'arccos',
+        'arccosh', 'arcsin', 'arcsinh', 'arctan', 'arctan2', 'arctanh', 'barkToHertz',
+        'beginPause', 'beginSendPraat', 'besselI', 'besselK', 'beta', 'beta2',
+        'binomialP', 'binomialQ', 'boolean', 'ceiling', 'chiSquareP', 'chiSquareQ',
+        'choice', 'comment', 'cos', 'cosh', 'createDirectory', 'deleteFile',
+        'demoClicked', 'demoClickedIn', 'demoCommandKeyPressed',
+        'demoExtraControlKeyPressed', 'demoInput', 'demoKeyPressed',
+        'demoOptionKeyPressed', 'demoShiftKeyPressed', 'demoShow', 'demoWaitForInput',
+        'demoWindowTitle', 'demoX', 'demoY', 'differenceLimensToPhon', 'do', 'editor',
+        'endPause', 'endSendPraat', 'endsWith', 'erb', 'erbToHertz', 'erf', 'erfc',
+        'exitScript', 'exp', 'extractNumber', 'fileReadable', 'fisherP', 'fisherQ',
+        'floor', 'gaussP', 'gaussQ', 'hertzToBark', 'hertzToErb', 'hertzToMel',
+        'hertzToSemitones', 'imax', 'imin', 'incompleteBeta', 'incompleteGammaP', 'index',
+        'index_regex', 'integer', 'invBinomialP', 'invBinomialQ', 'invChiSquareQ', 'invFisherQ',
+        'invGaussQ', 'invSigmoid', 'invStudentQ', 'length', 'ln', 'lnBeta', 'lnGamma',
+        'log10', 'log2', 'max', 'melToHertz', 'min', 'minusObject', 'natural', 'number',
+        'numberOfColumns', 'numberOfRows', 'numberOfSelected', 'objectsAreIdentical',
+        'option', 'optionMenu', 'pauseScript', 'phonToDifferenceLimens', 'plusObject',
+        'positive', 'randomBinomial', 'randomGauss', 'randomInteger', 'randomPoisson',
+        'randomUniform', 'real', 'readFile', 'removeObject', 'rindex', 'rindex_regex',
+        'round', 'runScript', 'runSystem', 'runSystem_nocheck', 'selectObject',
+        'selected', 'semitonesToHertz', 'sentence', 'sentencetext', 'sigmoid', 'sin', 'sinc',
+        'sincpi', 'sinh', 'soundPressureToPhon', 'sqrt', 'startsWith', 'studentP',
+        'studentQ', 'tan', 'tanh', 'text', 'variableExists', 'word', 'writeFile', 'writeFileLine',
+        'writeInfo', 'writeInfoLine',
+    )
+
+    functions_array = (
+        'linear', 'randomGauss', 'randomInteger', 'randomUniform', 'zero',
+    )
+
+    objects = (
+        'Activation', 'AffineTransform', 'AmplitudeTier', 'Art', 'Artword',
+        'Autosegment', 'BarkFilter', 'BarkSpectrogram', 'CCA', 'Categories',
+        'Cepstrogram', 'Cepstrum', 'Cepstrumc', 'ChebyshevSeries', 'ClassificationTable',
+        'Cochleagram', 'Collection', 'ComplexSpectrogram', 'Configuration', 'Confusion',
+        'ContingencyTable', 'Corpus', 'Correlation', 'Covariance',
+        'CrossCorrelationTable', 'CrossCorrelationTables', 'DTW', 'DataModeler',
+        'Diagonalizer', 'Discriminant', 'Dissimilarity', 'Distance', 'Distributions',
+        'DurationTier', 'EEG', 'ERP', 'ERPTier', 'EditCostsTable', 'EditDistanceTable',
+        'Eigen', 'Excitation', 'Excitations', 'ExperimentMFC', 'FFNet', 'FeatureWeights',
+        'FileInMemory', 'FilesInMemory', 'Formant', 'FormantFilter', 'FormantGrid',
+        'FormantModeler', 'FormantPoint', 'FormantTier', 'GaussianMixture', 'HMM',
+        'HMM_Observation', 'HMM_ObservationSequence', 'HMM_State', 'HMM_StateSequence',
+        'Harmonicity', 'ISpline', 'Index', 'Intensity', 'IntensityTier', 'IntervalTier',
+        'KNN', 'KlattGrid', 'KlattTable', 'LFCC', 'LPC', 'Label', 'LegendreSeries',
+        'LinearRegression', 'LogisticRegression', 'LongSound', 'Ltas', 'MFCC', 'MSpline',
+        'ManPages', 'Manipulation', 'Matrix', 'MelFilter', 'MelSpectrogram',
+        'MixingMatrix', 'Movie', 'Network', 'Object', 'OTGrammar', 'OTHistory', 'OTMulti',
+        'PCA', 'PairDistribution', 'ParamCurve', 'Pattern', 'Permutation', 'Photo',
+        'Pitch', 'PitchModeler', 'PitchTier', 'PointProcess', 'Polygon', 'Polynomial',
+        'PowerCepstrogram', 'PowerCepstrum', 'Procrustes', 'RealPoint', 'RealTier',
+        'ResultsMFC', 'Roots', 'SPINET', 'SSCP', 'SVD', 'Salience', 'ScalarProduct',
+        'Similarity', 'SimpleString', 'SortedSetOfString', 'Sound', 'Speaker',
+        'Spectrogram', 'Spectrum', 'SpectrumTier', 'SpeechSynthesizer', 'SpellingChecker',
+        'Strings', 'StringsIndex', 'Table', 'TableOfReal', 'TextGrid', 'TextInterval',
+        'TextPoint', 'TextTier', 'Tier', 'Transition', 'VocalTract', 'VocalTractTier',
+        'Weight', 'WordList',
+    )
+
+    variables_numeric = (
+        'macintosh', 'windows', 'unix', 'praatVersion', 'pi', 'e', 'undefined',
+    )
+
+    variables_string = (
+        'praatVersion', 'tab', 'shellDirectory', 'homeDirectory',
+        'preferencesDirectory', 'newline', 'temporaryDirectory',
+        'defaultDirectory',
+    )
+
+    object_attributes = (
+        'ncol', 'nrow', 'xmin', 'ymin', 'xmax', 'ymax', 'nx', 'ny', 'dx', 'dy',
+    )
+
+    tokens = {
+        'root': [
+            (r'(\s+)(#.*?$)',  bygroups(Whitespace, Comment.Single)),
+            (r'^#.*?$',        Comment.Single),
+            (r';[^\n]*',       Comment.Single),
+            (r'\s+',           Whitespace),
+
+            (r'\bprocedure\b', Keyword,       'procedure_definition'),
+            (r'\bcall\b',      Keyword,       'procedure_call'),
+            (r'@',             Name.Function, 'procedure_call'),
+
+            include('function_call'),
+
+            (words(keywords, suffix=r'\b'), Keyword),
+
+            (r'(\bform\b)(\s+)([^\n]+)',
+             bygroups(Keyword, Whitespace, String), 'old_form'),
+
+            (r'(print(?:line|tab)?|echo|exit|asserterror|pause|send(?:praat|socket)|'
+             r'include|execute|system(?:_nocheck)?)(\s+)',
+             bygroups(Keyword, Whitespace), 'string_unquoted'),
+
+            (r'(goto|label)(\s+)(\w+)', bygroups(Keyword, Whitespace, Name.Label)),
+
+            include('variable_name'),
+            include('number'),
+
+            (r'"', String, 'string'),
+
+            (words((objects), suffix=r'(?=\s+\S+\n)'), Name.Class, 'string_unquoted'),
+
+            (r'\b[A-Z]', Keyword, 'command'),
+            (r'(\.{3}|[)(,])', Punctuation),
+        ],
+        'command': [
+            (r'( ?[\w()-]+ ?)', Keyword),
+
+            include('string_interpolated'),
+
+            (r'\.{3}', Keyword, ('#pop', 'old_arguments')),
+            (r':', Keyword, ('#pop', 'comma_list')),
+            (r'\s', Whitespace, '#pop'),
+        ],
+        'procedure_call': [
+            (r'\s+', Whitespace),
+            (r'([\w.]+)(?:(:)|(?:(\s*)(\()))',
+             bygroups(Name.Function, Punctuation,
+                      Text.Whitespace, Punctuation), '#pop'),
+            (r'([\w.]+)', Name.Function, ('#pop', 'old_arguments')),
+        ],
+        'procedure_definition': [
+            (r'\s', Whitespace),
+            (r'([\w.]+)(\s*?[(:])',
+             bygroups(Name.Function, Whitespace), '#pop'),
+            (r'([\w.]+)([^\n]*)',
+             bygroups(Name.Function, Text), '#pop'),
+        ],
+        'function_call': [
+            (words(functions_string, suffix=r'\$(?=\s*[:(])'), Name.Function, 'function'),
+            (words(functions_array, suffix=r'#(?=\s*[:(])'),   Name.Function, 'function'),
+            (words(functions_numeric, suffix=r'(?=\s*[:(])'),  Name.Function, 'function'),
+        ],
+        'function': [
+            (r'\s+',   Whitespace),
+            (r':',     Punctuation, ('#pop', 'comma_list')),
+            (r'\s*\(', Punctuation, ('#pop', 'comma_list')),
+        ],
+        'comma_list': [
+            (r'(\s*\n\s*)(\.{3})', bygroups(Whitespace, Punctuation)),
+
+            (r'(\s*)(?:([)\]])|(\n))', bygroups(
+                Whitespace, Punctuation, Whitespace), '#pop'),
+
+            (r'\s+', Whitespace),
+            (r'"',   String, 'string'),
+            (r'\b(if|then|else|fi|endif)\b', Keyword),
+
+            include('function_call'),
+            include('variable_name'),
+            include('operator'),
+            include('number'),
+
+            (r'[()]', Text),
+            (r',', Punctuation),
+        ],
+        'old_arguments': [
+            (r'\n', Whitespace, '#pop'),
+
+            include('variable_name'),
+            include('operator'),
+            include('number'),
+
+            (r'"', String, 'string'),
+            (r'[^\n]', Text),
+        ],
+        'number': [
+            (r'\n', Whitespace, '#pop'),
+            (r'\b\d+(\.\d*)?([eE][-+]?\d+)?%?', Number),
+        ],
+        'object_reference': [
+            include('string_interpolated'),
+            (r'([a-z][a-zA-Z0-9_]*|\d+)', Name.Builtin),
+
+            (words(object_attributes, prefix=r'\.'), Name.Builtin, '#pop'),
+
+            (r'\$', Name.Builtin),
+            (r'\[', Text, '#pop'),
+        ],
+        'variable_name': [
+            include('operator'),
+            include('number'),
+
+            (words(variables_string,  suffix=r'\$'), Name.Variable.Global),
+            (words(variables_numeric,
+             suffix=r'(?=[^a-zA-Z0-9_."\'$#\[:(]|\s|^|$)'),
+             Name.Variable.Global),
+
+            (words(objects, prefix=r'\b', suffix=r"(_)"),
+             bygroups(Name.Builtin, Name.Builtin),
+             'object_reference'),
+
+            (r'\.?_?[a-z][\w.]*(\$|#)?', Text),
+            (r'[\[\]]', Punctuation, 'comma_list'),
+
+            include('string_interpolated'),
+        ],
+        'operator': [
+            (r'([+\/*<>=!-]=?|[&*|][&*|]?|\^|<>)',       Operator),
+            (r'(?', Punctuation),
+            (r'"(?:\\x[0-9a-fA-F]+\\|\\u[0-9a-fA-F]{4}|\\U[0-9a-fA-F]{8}|'
+             r'\\[0-7]+\\|\\["\\abcefnrstv]|[^\\"])*"', String.Double),
+            (r"'(?:''|[^'])*'", String.Atom),  # quoted atom
+            # Needs to not be followed by an atom.
+            # (r'=(?=\s|[a-zA-Z\[])', Operator),
+            (r'is\b', Operator),
+            (r'(<|>|=<|>=|==|=:=|=|/|//|\*|\+|-)(?=\s|[a-zA-Z0-9\[])',
+             Operator),
+            (r'(mod|div|not)\b', Operator),
+            (r'_', Keyword),  # The don't-care variable
+            (r'([a-z]+)(:)', bygroups(Name.Namespace, Punctuation)),
+            (r'([a-z\u00c0-\u1fff\u3040-\ud7ff\ue000-\uffef]'
+             r'[\w$\u00c0-\u1fff\u3040-\ud7ff\ue000-\uffef]*)'
+             r'(\s*)(:-|-->)',
+             bygroups(Name.Function, Text, Operator)),  # function defn
+            (r'([a-z\u00c0-\u1fff\u3040-\ud7ff\ue000-\uffef]'
+             r'[\w$\u00c0-\u1fff\u3040-\ud7ff\ue000-\uffef]*)'
+             r'(\s*)(\()',
+             bygroups(Name.Function, Text, Punctuation)),
+            (r'[a-z\u00c0-\u1fff\u3040-\ud7ff\ue000-\uffef]'
+             r'[\w$\u00c0-\u1fff\u3040-\ud7ff\ue000-\uffef]*',
+             String.Atom),  # atom, characters
+            # This one includes !
+            (r'[#&*+\-./:<=>?@\\^~\u00a1-\u00bf\u2010-\u303f]+',
+             String.Atom),  # atom, graphics
+            (r'[A-Z_]\w*', Name.Variable),
+            (r'\s+|[\u2000-\u200f\ufff0-\ufffe\uffef]', Text),
+        ],
+        'nested-comment': [
+            (r'\*/', Comment.Multiline, '#pop'),
+            (r'/\*', Comment.Multiline, '#push'),
+            (r'[^*/]+', Comment.Multiline),
+            (r'[*/]', Comment.Multiline),
+        ],
+    }
+
+    def analyse_text(text):
+        """Competes with IDL and Visual Prolog on *.pro"""
+        if ':-' in text:
+            # Visual Prolog also uses :-
+            return 0.5
+        else:
+            return 0
+
+
+class LogtalkLexer(RegexLexer):
+    """
+    For Logtalk source code.
+    """
+
+    name = 'Logtalk'
+    url = 'http://logtalk.org/'
+    aliases = ['logtalk']
+    filenames = ['*.lgt', '*.logtalk']
+    mimetypes = ['text/x-logtalk']
+    version_added = '0.10'
+
+    tokens = {
+        'root': [
+            # Directives
+            (r'^\s*:-\s', Punctuation, 'directive'),
+            # Comments
+            (r'%.*?\n', Comment),
+            (r'/\*(.|\n)*?\*/', Comment),
+            # Whitespace
+            (r'\n', Text),
+            (r'\s+', Text),
+            # Numbers
+            (r"0'[\\]?.", Number),
+            (r'0b[01]+', Number.Bin),
+            (r'0o[0-7]+', Number.Oct),
+            (r'0x[0-9a-fA-F]+', Number.Hex),
+            (r'\d+\.?\d*((e|E)(\+|-)?\d+)?', Number),
+            # Variables
+            (r'([A-Z_][a-zA-Z0-9_]*)', Name.Variable),
+            # Event handlers
+            (r'(after|before)(?=[(])', Keyword),
+            # Message forwarding handler
+            (r'forward(?=[(])', Keyword),
+            # Execution-context methods
+            (r'(context|parameter|this|se(lf|nder))(?=[(])', Keyword),
+            # Reflection
+            (r'(current_predicate|predicate_property)(?=[(])', Keyword),
+            # DCGs and term expansion
+            (r'(expand_(goal|term)|(goal|term)_expansion|phrase)(?=[(])', Keyword),
+            # Entity
+            (r'(abolish|c(reate|urrent))_(object|protocol|category)(?=[(])', Keyword),
+            (r'(object|protocol|category)_property(?=[(])', Keyword),
+            # Entity relations
+            (r'co(mplements_object|nforms_to_protocol)(?=[(])', Keyword),
+            (r'extends_(object|protocol|category)(?=[(])', Keyword),
+            (r'imp(lements_protocol|orts_category)(?=[(])', Keyword),
+            (r'(instantiat|specializ)es_class(?=[(])', Keyword),
+            # Events
+            (r'(current_event|(abolish|define)_events)(?=[(])', Keyword),
+            # Flags
+            (r'(create|current|set)_logtalk_flag(?=[(])', Keyword),
+            # Compiling, loading, and library paths
+            (r'logtalk_(compile|l(ibrary_path|oad|oad_context)|make(_target_action)?)(?=[(])', Keyword),
+            (r'\blogtalk_make\b', Keyword),
+            # Database
+            (r'(clause|retract(all)?)(?=[(])', Keyword),
+            (r'a(bolish|ssert(a|z))(?=[(])', Keyword),
+            # Control constructs
+            (r'(ca(ll|tch)|throw)(?=[(])', Keyword),
+            (r'(fa(il|lse)|true|(instantiation|system)_error)\b', Keyword),
+            (r'(uninstantiation|type|domain|existence|permission|representation|evaluation|resource|syntax)_error(?=[(])', Keyword),
+            # All solutions
+            (r'((bag|set)of|f(ind|or)all)(?=[(])', Keyword),
+            # Multi-threading predicates
+            (r'threaded(_(ca(ll|ncel)|once|ignore|exit|peek|wait|notify))?(?=[(])', Keyword),
+            # Engine predicates
+            (r'threaded_engine(_(create|destroy|self|next|next_reified|yield|post|fetch))?(?=[(])', Keyword),
+            # Term unification
+            (r'(subsumes_term|unify_with_occurs_check)(?=[(])', Keyword),
+            # Term creation and decomposition
+            (r'(functor|arg|copy_term|numbervars|term_variables)(?=[(])', Keyword),
+            # Evaluable functors
+            (r'(div|rem|m(ax|in|od)|abs|sign)(?=[(])', Keyword),
+            (r'float(_(integer|fractional)_part)?(?=[(])', Keyword),
+            (r'(floor|t(an|runcate)|round|ceiling)(?=[(])', Keyword),
+            # Other arithmetic functors
+            (r'(cos|a(cos|sin|tan|tan2)|exp|log|s(in|qrt)|xor)(?=[(])', Keyword),
+            # Term testing
+            (r'(var|atom(ic)?|integer|float|c(allable|ompound)|n(onvar|umber)|ground|acyclic_term)(?=[(])', Keyword),
+            # Term comparison
+            (r'compare(?=[(])', Keyword),
+            # Stream selection and control
+            (r'(curren|se)t_(in|out)put(?=[(])', Keyword),
+            (r'(open|close)(?=[(])', Keyword),
+            (r'flush_output(?=[(])', Keyword),
+            (r'(at_end_of_stream|flush_output)\b', Keyword),
+            (r'(stream_property|at_end_of_stream|set_stream_position)(?=[(])', Keyword),
+            # Character and byte input/output
+            (r'(nl|(get|peek|put)_(byte|c(har|ode)))(?=[(])', Keyword),
+            (r'\bnl\b', Keyword),
+            # Term input/output
+            (r'read(_term)?(?=[(])', Keyword),
+            (r'write(q|_(canonical|term))?(?=[(])', Keyword),
+            (r'(current_)?op(?=[(])', Keyword),
+            (r'(current_)?char_conversion(?=[(])', Keyword),
+            # Atomic term processing
+            (r'atom_(length|c(hars|o(ncat|des)))(?=[(])', Keyword),
+            (r'(char_code|sub_atom)(?=[(])', Keyword),
+            (r'number_c(har|ode)s(?=[(])', Keyword),
+            # Implementation defined hooks functions
+            (r'(se|curren)t_prolog_flag(?=[(])', Keyword),
+            (r'\bhalt\b', Keyword),
+            (r'halt(?=[(])', Keyword),
+            # Message sending operators
+            (r'(::|:|\^\^)', Operator),
+            # External call
+            (r'[{}]', Keyword),
+            # Logic and control
+            (r'(ignore|once)(?=[(])', Keyword),
+            (r'\brepeat\b', Keyword),
+            # Sorting
+            (r'(key)?sort(?=[(])', Keyword),
+            # Bitwise functors
+            (r'(>>|<<|/\\|\\\\|\\)', Operator),
+            # Predicate aliases
+            (r'\bas\b', Operator),
+            # Arithmetic evaluation
+            (r'\bis\b', Keyword),
+            # Arithmetic comparison
+            (r'(=:=|=\\=|<|=<|>=|>)', Operator),
+            # Term creation and decomposition
+            (r'=\.\.', Operator),
+            # Term unification
+            (r'(=|\\=)', Operator),
+            # Term comparison
+            (r'(==|\\==|@=<|@<|@>=|@>)', Operator),
+            # Evaluable functors
+            (r'(//|[-+*/])', Operator),
+            (r'\b(e|pi|div|mod|rem)\b', Operator),
+            # Other arithmetic functors
+            (r'\b\*\*\b', Operator),
+            # DCG rules
+            (r'-->', Operator),
+            # Control constructs
+            (r'([!;]|->)', Operator),
+            # Logic and control
+            (r'\\+', Operator),
+            # Mode operators
+            (r'[?@]', Operator),
+            # Existential quantifier
+            (r'\^', Operator),
+            # Punctuation
+            (r'[()\[\],.|]', Text),
+            # Atoms
+            (r"[a-z][a-zA-Z0-9_]*", Text),
+            (r"'", String, 'quoted_atom'),
+            # Double-quoted terms
+            (r'"', String, 'double_quoted_term'),
+        ],
+
+        'quoted_atom': [
+            (r"''", String),
+            (r"'", String, '#pop'),
+            (r'\\([\\abfnrtv"\']|(x[a-fA-F0-9]+|[0-7]+)\\)', String.Escape),
+            (r"[^\\'\n]+", String),
+            (r'\\', String),
+        ],
+
+        'double_quoted_term': [
+            (r'""', String),
+            (r'"', String, '#pop'),
+            (r'\\([\\abfnrtv"\']|(x[a-fA-F0-9]+|[0-7]+)\\)', String.Escape),
+            (r'[^\\"\n]+', String),
+            (r'\\', String),
+        ],
+
+        'directive': [
+            # Conditional compilation directives
+            (r'(el)?if(?=[(])', Keyword, 'root'),
+            (r'(e(lse|ndif))(?=[.])', Keyword, 'root'),
+            # Entity directives
+            (r'(category|object|protocol)(?=[(])', Keyword, 'entityrelations'),
+            (r'(end_(category|object|protocol))(?=[.])', Keyword, 'root'),
+            # Predicate scope directives
+            (r'(public|protected|private)(?=[(])', Keyword, 'root'),
+            # Other directives
+            (r'e(n(coding|sure_loaded)|xport)(?=[(])', Keyword, 'root'),
+            (r'in(clude|itialization|fo)(?=[(])', Keyword, 'root'),
+            (r'(built_in|dynamic|synchronized|threaded)(?=[.])', Keyword, 'root'),
+            (r'(alias|d(ynamic|iscontiguous)|m(eta_(non_terminal|predicate)|ode|ultifile)|s(et_(logtalk|prolog)_flag|ynchronized))(?=[(])', Keyword, 'root'),
+            (r'op(?=[(])', Keyword, 'root'),
+            (r'(c(alls|oinductive)|module|reexport|use(s|_module))(?=[(])', Keyword, 'root'),
+            (r'[a-z][a-zA-Z0-9_]*(?=[(])', Text, 'root'),
+            (r'[a-z][a-zA-Z0-9_]*(?=[.])', Text, 'root'),
+        ],
+
+        'entityrelations': [
+            (r'(complements|extends|i(nstantiates|mp(lements|orts))|specializes)(?=[(])', Keyword),
+            # Numbers
+            (r"0'[\\]?.", Number),
+            (r'0b[01]+', Number.Bin),
+            (r'0o[0-7]+', Number.Oct),
+            (r'0x[0-9a-fA-F]+', Number.Hex),
+            (r'\d+\.?\d*((e|E)(\+|-)?\d+)?', Number),
+            # Variables
+            (r'([A-Z_][a-zA-Z0-9_]*)', Name.Variable),
+            # Atoms
+            (r"[a-z][a-zA-Z0-9_]*", Text),
+            (r"'", String, 'quoted_atom'),
+            # Double-quoted terms
+            (r'"', String, 'double_quoted_term'),
+            # End of entity-opening directive
+            (r'([)]\.)', Text, 'root'),
+            # Scope operator
+            (r'(::)', Operator),
+            # Punctuation
+            (r'[()\[\],.|]', Text),
+            # Comments
+            (r'%.*?\n', Comment),
+            (r'/\*(.|\n)*?\*/', Comment),
+            # Whitespace
+            (r'\n', Text),
+            (r'\s+', Text),
+        ]
+    }
+
+    def analyse_text(text):
+        if ':- object(' in text:
+            return 1.0
+        elif ':- protocol(' in text:
+            return 1.0
+        elif ':- category(' in text:
+            return 1.0
+        elif re.search(r'^:-\s[a-z]', text, re.M):
+            return 0.9
+        else:
+            return 0.0
diff --git a/pygments/lexers/promql.py b/pygments/lexers/promql.py
new file mode 100644
index 0000000000000000000000000000000000000000..cad3c254a1202682f4a831d4c09472bce34c33f0
--- /dev/null
+++ b/pygments/lexers/promql.py
@@ -0,0 +1,176 @@
+"""
+    pygments.lexers.promql
+    ~~~~~~~~~~~~~~~~~~~~~~
+
+    Lexer for Prometheus Query Language.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, bygroups, default, words
+from pygments.token import Comment, Keyword, Name, Number, Operator, \
+    Punctuation, String, Whitespace
+
+__all__ = ["PromQLLexer"]
+
+
+class PromQLLexer(RegexLexer):
+    """
+    For PromQL queries.
+
+    For details about the grammar see:
+    https://github.com/prometheus/prometheus/tree/master/promql/parser
+
+    .. versionadded: 2.7
+    """
+
+    name = "PromQL"
+    url = 'https://prometheus.io/docs/prometheus/latest/querying/basics/'
+    aliases = ["promql"]
+    filenames = ["*.promql"]
+    version_added = ''
+
+    base_keywords = (
+        words(
+            (
+                "bool",
+                "by",
+                "group_left",
+                "group_right",
+                "ignoring",
+                "offset",
+                "on",
+                "without",
+            ),
+            suffix=r"\b",
+        ),
+        Keyword,
+    )
+
+    aggregator_keywords = (
+        words(
+            (
+                "sum",
+                "min",
+                "max",
+                "avg",
+                "group",
+                "stddev",
+                "stdvar",
+                "count",
+                "count_values",
+                "bottomk",
+                "topk",
+                "quantile",
+            ),
+            suffix=r"\b",
+        ),
+        Keyword,
+    )
+
+    function_keywords = (
+        words(
+            (
+                "abs",
+                "absent",
+                "absent_over_time",
+                "avg_over_time",
+                "ceil",
+                "changes",
+                "clamp_max",
+                "clamp_min",
+                "count_over_time",
+                "day_of_month",
+                "day_of_week",
+                "days_in_month",
+                "delta",
+                "deriv",
+                "exp",
+                "floor",
+                "histogram_quantile",
+                "holt_winters",
+                "hour",
+                "idelta",
+                "increase",
+                "irate",
+                "label_join",
+                "label_replace",
+                "ln",
+                "log10",
+                "log2",
+                "max_over_time",
+                "min_over_time",
+                "minute",
+                "month",
+                "predict_linear",
+                "quantile_over_time",
+                "rate",
+                "resets",
+                "round",
+                "scalar",
+                "sort",
+                "sort_desc",
+                "sqrt",
+                "stddev_over_time",
+                "stdvar_over_time",
+                "sum_over_time",
+                "time",
+                "timestamp",
+                "vector",
+                "year",
+            ),
+            suffix=r"\b",
+        ),
+        Keyword.Reserved,
+    )
+
+    tokens = {
+        "root": [
+            (r"\n", Whitespace),
+            (r"\s+", Whitespace),
+            (r",", Punctuation),
+            # Keywords
+            base_keywords,
+            aggregator_keywords,
+            function_keywords,
+            # Offsets
+            (r"[1-9][0-9]*[smhdwy]", String),
+            # Numbers
+            (r"-?[0-9]+\.[0-9]+", Number.Float),
+            (r"-?[0-9]+", Number.Integer),
+            # Comments
+            (r"#.*?$", Comment.Single),
+            # Operators
+            (r"(\+|\-|\*|\/|\%|\^)", Operator),
+            (r"==|!=|>=|<=|<|>", Operator),
+            (r"and|or|unless", Operator.Word),
+            # Metrics
+            (r"[_a-zA-Z][a-zA-Z0-9_]+", Name.Variable),
+            # Params
+            (r'(["\'])(.*?)(["\'])', bygroups(Punctuation, String, Punctuation)),
+            # Other states
+            (r"\(", Operator, "function"),
+            (r"\)", Operator),
+            (r"\{", Punctuation, "labels"),
+            (r"\[", Punctuation, "range"),
+        ],
+        "labels": [
+            (r"\}", Punctuation, "#pop"),
+            (r"\n", Whitespace),
+            (r"\s+", Whitespace),
+            (r",", Punctuation),
+            (r'([_a-zA-Z][a-zA-Z0-9_]*?)(\s*?)(=~|!=|=|!~)(\s*?)("|\')(.*?)("|\')',
+             bygroups(Name.Label, Whitespace, Operator, Whitespace,
+                      Punctuation, String, Punctuation)),
+        ],
+        "range": [
+            (r"\]", Punctuation, "#pop"),
+            (r"[1-9][0-9]*[smhdwy]", String),
+        ],
+        "function": [
+            (r"\)", Operator, "#pop"),
+            (r"\(", Operator, "#push"),
+            default("#pop"),
+        ],
+    }
diff --git a/pygments/lexers/prql.py b/pygments/lexers/prql.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee95d2d47491a45479aa79a0e7c2c24a4cef01cf
--- /dev/null
+++ b/pygments/lexers/prql.py
@@ -0,0 +1,251 @@
+"""
+    pygments.lexers.prql
+    ~~~~~~~~~~~~~~~~~~~~
+
+    Lexer for the PRQL query language.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, combined, words, include, bygroups
+from pygments.token import Comment, Literal, Keyword, Name, Number, Operator, \
+    Punctuation, String, Text, Whitespace
+
+__all__ = ['PrqlLexer']
+
+
+class PrqlLexer(RegexLexer):
+    """
+    For PRQL source code.
+
+    grammar: https://github.com/PRQL/prql/tree/main/grammars
+    """
+
+    name = 'PRQL'
+    url = 'https://prql-lang.org/'
+    aliases = ['prql']
+    filenames = ['*.prql']
+    mimetypes = ['application/prql', 'application/x-prql']
+    version_added = '2.17'
+
+    builtinTypes = words((
+        "bool",
+        "int",
+        "int8", "int16", "int32", "int64", "int128",
+        "float",
+        "text",
+        "set"), suffix=r'\b')
+
+    def innerstring_rules(ttype):
+        return [
+            # the new style '{}'.format(...) string formatting
+            (r'\{'
+             r'((\w+)((\.\w+)|(\[[^\]]+\]))*)?'  # field name
+             r'(\:(.?[<>=\^])?[-+ ]?#?0?(\d+)?,?(\.\d+)?[E-GXb-gnosx%]?)?'
+             r'\}', String.Interpol),
+
+            (r'[^\\\'"%{\n]+', ttype),
+            (r'[\'"\\]', ttype),
+            (r'%|(\{{1,2})', ttype)
+        ]
+
+    def fstring_rules(ttype):
+        return [
+            (r'\}', String.Interpol),
+            (r'\{', String.Interpol, 'expr-inside-fstring'),
+            (r'[^\\\'"{}\n]+', ttype),
+            (r'[\'"\\]', ttype),
+        ]
+
+    tokens = {
+        'root': [
+
+            # Comments
+            (r'#!.*', String.Doc),
+            (r'#.*', Comment.Single),
+
+            # Whitespace
+            (r'\s+', Whitespace),
+
+            # Modules
+            (r'^(\s*)(module)(\s*)',
+             bygroups(Whitespace, Keyword.Namespace, Whitespace),
+             'imports'),
+
+            (builtinTypes, Keyword.Type),
+
+            # Main
+            (r'^prql ', Keyword.Reserved),
+
+            ('let', Keyword.Declaration),
+
+            include('keywords'),
+            include('expr'),
+
+            # Transforms
+            (r'^[A-Za-z_][a-zA-Z0-9_]*', Keyword),
+        ],
+        'expr': [
+            # non-raw f-strings
+            ('(f)(""")', bygroups(String.Affix, String.Double),
+             combined('fstringescape', 'tdqf')),
+            ("(f)(''')", bygroups(String.Affix, String.Single),
+             combined('fstringescape', 'tsqf')),
+            ('(f)(")', bygroups(String.Affix, String.Double),
+             combined('fstringescape', 'dqf')),
+            ("(f)(')", bygroups(String.Affix, String.Single),
+             combined('fstringescape', 'sqf')),
+
+            # non-raw s-strings
+            ('(s)(""")', bygroups(String.Affix, String.Double),
+             combined('stringescape', 'tdqf')),
+            ("(s)(''')", bygroups(String.Affix, String.Single),
+             combined('stringescape', 'tsqf')),
+            ('(s)(")', bygroups(String.Affix, String.Double),
+             combined('stringescape', 'dqf')),
+            ("(s)(')", bygroups(String.Affix, String.Single),
+             combined('stringescape', 'sqf')),
+
+            # raw strings
+            ('(?i)(r)(""")',
+             bygroups(String.Affix, String.Double), 'tdqs'),
+            ("(?i)(r)(''')",
+             bygroups(String.Affix, String.Single), 'tsqs'),
+            ('(?i)(r)(")',
+             bygroups(String.Affix, String.Double), 'dqs'),
+            ("(?i)(r)(')",
+             bygroups(String.Affix, String.Single), 'sqs'),
+
+            # non-raw strings
+            ('"""', String.Double, combined('stringescape', 'tdqs')),
+            ("'''", String.Single, combined('stringescape', 'tsqs')),
+            ('"', String.Double, combined('stringescape', 'dqs')),
+            ("'", String.Single, combined('stringescape', 'sqs')),
+
+            # Time and dates
+            (r'@\d{4}-\d{2}-\d{2}T\d{2}(:\d{2})?(:\d{2})?(\.\d{1,6})?(Z|[+-]\d{1,2}(:\d{1,2})?)?', Literal.Date),
+            (r'@\d{4}-\d{2}-\d{2}', Literal.Date),
+            (r'@\d{2}(:\d{2})?(:\d{2})?(\.\d{1,6})?(Z|[+-]\d{1,2}(:\d{1,2})?)?', Literal.Date),
+
+            (r'[^\S\n]+', Text),
+            include('numbers'),
+            (r'->|=>|==|!=|>=|<=|~=|&&|\|\||\?\?|\/\/', Operator),
+            (r'[-~+/*%=<>&^|.@]', Operator),
+            (r'[]{}:(),;[]', Punctuation),
+            include('functions'),
+
+            # Variable Names
+            (r'[A-Za-z_][a-zA-Z0-9_]*', Name.Variable),
+        ],
+        'numbers': [
+            (r'(\d(?:_?\d)*\.(?:\d(?:_?\d)*)?|(?:\d(?:_?\d)*)?\.\d(?:_?\d)*)'
+             r'([eE][+-]?\d(?:_?\d)*)?', Number.Float),
+            (r'\d(?:_?\d)*[eE][+-]?\d(?:_?\d)*j?', Number.Float),
+            (r'0[oO](?:_?[0-7])+', Number.Oct),
+            (r'0[bB](?:_?[01])+', Number.Bin),
+            (r'0[xX](?:_?[a-fA-F0-9])+', Number.Hex),
+            (r'\d(?:_?\d)*', Number.Integer),
+        ],
+        'fstringescape': [
+            include('stringescape'),
+        ],
+        'bytesescape': [
+            (r'\\([\\bfnrt"\']|\n|x[a-fA-F0-9]{2}|[0-7]{1,3})', String.Escape)
+        ],
+        'stringescape': [
+            (r'\\(N\{.*?\}|u\{[a-fA-F0-9]{1,6}\})', String.Escape),
+            include('bytesescape')
+        ],
+        'fstrings-single': fstring_rules(String.Single),
+        'fstrings-double': fstring_rules(String.Double),
+        'strings-single': innerstring_rules(String.Single),
+        'strings-double': innerstring_rules(String.Double),
+        'dqf': [
+            (r'"', String.Double, '#pop'),
+            (r'\\\\|\\"|\\\n', String.Escape),  # included here for raw strings
+            include('fstrings-double')
+        ],
+        'sqf': [
+            (r"'", String.Single, '#pop'),
+            (r"\\\\|\\'|\\\n", String.Escape),  # included here for raw strings
+            include('fstrings-single')
+        ],
+        'dqs': [
+            (r'"', String.Double, '#pop'),
+            (r'\\\\|\\"|\\\n', String.Escape),  # included here for raw strings
+            include('strings-double')
+        ],
+        'sqs': [
+            (r"'", String.Single, '#pop'),
+            (r"\\\\|\\'|\\\n", String.Escape),  # included here for raw strings
+            include('strings-single')
+        ],
+        'tdqf': [
+            (r'"""', String.Double, '#pop'),
+            include('fstrings-double'),
+            (r'\n', String.Double)
+        ],
+        'tsqf': [
+            (r"'''", String.Single, '#pop'),
+            include('fstrings-single'),
+            (r'\n', String.Single)
+        ],
+        'tdqs': [
+            (r'"""', String.Double, '#pop'),
+            include('strings-double'),
+            (r'\n', String.Double)
+        ],
+        'tsqs': [
+            (r"'''", String.Single, '#pop'),
+            include('strings-single'),
+            (r'\n', String.Single)
+        ],
+
+        'expr-inside-fstring': [
+            (r'[{([]', Punctuation, 'expr-inside-fstring-inner'),
+            # without format specifier
+            (r'(=\s*)?'         # debug (https://bugs.python.org/issue36817)
+             r'\}', String.Interpol, '#pop'),
+            # with format specifier
+            # we'll catch the remaining '}' in the outer scope
+            (r'(=\s*)?'         # debug (https://bugs.python.org/issue36817)
+             r':', String.Interpol, '#pop'),
+            (r'\s+', Whitespace),  # allow new lines
+            include('expr'),
+        ],
+        'expr-inside-fstring-inner': [
+            (r'[{([]', Punctuation, 'expr-inside-fstring-inner'),
+            (r'[])}]', Punctuation, '#pop'),
+            (r'\s+', Whitespace),  # allow new lines
+            include('expr'),
+        ],
+        'keywords': [
+            (words((
+                'into', 'case', 'type', 'module', 'internal',
+            ), suffix=r'\b'),
+                Keyword),
+            (words(('true', 'false', 'null'), suffix=r'\b'), Keyword.Constant),
+        ],
+        'functions': [
+            (words((
+                "min", "max", "sum", "average", "stddev", "every", "any",
+                "concat_array", "count", "lag", "lead", "first", "last",
+                "rank", "rank_dense", "row_number", "round", "as", "in",
+                "tuple_every", "tuple_map", "tuple_zip", "_eq", "_is_null",
+                "from_text", "lower", "upper", "read_parquet", "read_csv"),
+                suffix=r'\b'),
+             Name.Function),
+        ],
+
+        'comment': [
+            (r'-(?!\})', Comment.Multiline),
+            (r'\{-', Comment.Multiline, 'comment'),
+            (r'[^-}]', Comment.Multiline),
+            (r'-\}', Comment.Multiline, '#pop'),
+        ],
+
+        'imports': [
+            (r'\w+(\.\w+)*', Name.Class, '#pop'),
+        ],
+    }
diff --git a/pygments/lexers/ptx.py b/pygments/lexers/ptx.py
new file mode 100644
index 0000000000000000000000000000000000000000..784ca13a6f0c0ba5464706f68c84eeae9d36c9e9
--- /dev/null
+++ b/pygments/lexers/ptx.py
@@ -0,0 +1,119 @@
+"""
+    pygments.lexers.ptx
+    ~~~~~~~~~~~~~~~~~~~
+
+    Lexer for other PTX language.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, include, words
+from pygments.token import Comment, Keyword, Name, String, Number, \
+    Punctuation, Whitespace, Operator
+
+__all__ = ["PtxLexer"]
+
+
+class PtxLexer(RegexLexer):
+    """
+    For NVIDIA `PTX `_
+    source.
+    """
+    name = 'PTX'
+    url = "https://docs.nvidia.com/cuda/parallel-thread-execution/"
+    filenames = ['*.ptx']
+    aliases = ['ptx']
+    mimetypes = ['text/x-ptx']
+    version_added = '2.16'
+
+    #: optional Comment or Whitespace
+    string = r'"[^"]*?"'
+    followsym = r'[a-zA-Z0-9_$]'
+    identifier = r'([-a-zA-Z$._][\w\-$.]*|' + string + ')'
+    block_label = r'(' + identifier + r'|(\d+))'
+
+    tokens = {
+        'root': [
+            include('whitespace'),
+
+            (block_label + r'\s*:', Name.Label),
+
+            include('keyword'),
+
+            (r'%' + identifier, Name.Variable),
+            (r'%\d+', Name.Variable.Anonymous),
+            (r'c?' + string, String),
+            (identifier, Name.Variable),
+            (r';', Punctuation),
+            (r'[*+-/]', Operator),
+
+            (r'0[xX][a-fA-F0-9]+', Number),
+            (r'-?\d+(?:[.]\d+)?(?:[eE][-+]?\d+(?:[.]\d+)?)?', Number),
+
+            (r'[=<>{}\[\]()*.,!]|x\b', Punctuation)
+
+        ],
+        'whitespace': [
+            (r'(\n|\s+)+', Whitespace),
+            (r'//.*?\n', Comment)
+        ],
+
+        'keyword': [
+            # Instruction keywords
+            (words((
+                'abs', 'discard', 'min', 'shf', 'vadd',
+                'activemask', 'div', 'mma', 'shfl', 'vadd2',
+                'add', 'dp2a', 'mov', 'shl', 'vadd4',
+                'addc', 'dp4a', 'movmatrix', 'shr', 'vavrg2',
+                'alloca', 'elect', 'mul', 'sin', 'vavrg4',
+                'and', 'ex2', 'mul24', 'slct', 'vmad',
+                'applypriority', 'exit', 'multimem', 'sqrt', 'vmax',
+                'atom', 'fence', 'nanosleep', 'st', 'vmax2',
+                'bar', 'fma', 'neg', 'stackrestore', 'vmax4',
+                'barrier', 'fns', 'not', 'stacksave', 'vmin',
+                'bfe', 'getctarank', 'or', 'stmatrix', 'vmin2',
+                'bfi', 'griddepcontrol', 'pmevent', 'sub', 'vmin4',
+                'bfind', 'isspacep', 'popc', 'subc', 'vote',
+                'bmsk', 'istypep', 'prefetch', 'suld', 'vset',
+                'bra', 'ld', 'prefetchu', 'suq', 'vset2',
+                'brev', 'ldmatrix', 'prmt', 'sured', 'vset4',
+                'brkpt', 'ldu', 'rcp', 'sust', 'vshl',
+                'brx', 'lg2', 'red', 'szext', 'vshr',
+                'call', 'lop3', 'redux', 'tanh', 'vsub',
+                'clz', 'mad', 'rem', 'testp', 'vsub2',
+                'cnot', 'mad24', 'ret', 'tex', 'vsub4',
+                'copysign', 'madc', 'rsqrt', 'tld4', 'wgmma',
+                'cos', 'mapa', 'sad', 'trap', 'wmma',
+                'cp', 'match', 'selp', 'txq', 'xor',
+                'createpolicy', 'max', 'set', 'vabsdiff', 'cvt',
+                'mbarrier', 'setmaxnreg', 'vabsdiff2', 'cvta',
+                'membar', 'setp', 'vabsdiff4')), Keyword),
+            # State Spaces and Suffixes
+            (words((
+                'reg', '.sreg', '.const', '.global',
+                '.local', '.param', '.shared', '.tex',
+                '.wide', '.loc'
+            )), Keyword.Pseudo),
+            # PTX Directives
+            (words((
+                '.address_size', '.explicitcluster', '.maxnreg', '.section',
+                '.alias', '.extern', '.maxntid', '.shared',
+                '.align', '.file', '.minnctapersm', '.sreg',
+                '.branchtargets', '.func', '.noreturn', '.target',
+                '.callprototype', '.global', '.param', '.tex',
+                '.calltargets', '.loc', '.pragma', '.version',
+                '.common', '.local', '.reg', '.visible',
+                '.const', '.maxclusterrank', '.reqnctapercluster', '.weak',
+                '.entry', '.maxnctapersm', '.reqntid')), Keyword.Reserved),
+            # Fundamental Types
+            (words((
+                '.s8', '.s16', '.s32', '.s64',
+                '.u8', '.u16', '.u32', '.u64',
+                '.f16', '.f16x2', '.f32', '.f64',
+                '.b8', '.b16', '.b32', '.b64',
+                '.pred'
+            )), Keyword.Type)
+        ],
+
+    }
diff --git a/pygments/lexers/python.py b/pygments/lexers/python.py
new file mode 100644
index 0000000000000000000000000000000000000000..805f6ff2ac778176926cbdb40eca17ef4c93ee70
--- /dev/null
+++ b/pygments/lexers/python.py
@@ -0,0 +1,1201 @@
+"""
+    pygments.lexers.python
+    ~~~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for Python and related languages.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import keyword
+
+from pygments.lexer import DelegatingLexer, RegexLexer, include, \
+    bygroups, using, default, words, combined, this
+from pygments.util import get_bool_opt, shebang_matches
+from pygments.token import Text, Comment, Operator, Keyword, Name, String, \
+    Number, Punctuation, Generic, Other, Error, Whitespace
+from pygments import unistring as uni
+
+__all__ = ['PythonLexer', 'PythonConsoleLexer', 'PythonTracebackLexer',
+           'Python2Lexer', 'Python2TracebackLexer',
+           'CythonLexer', 'DgLexer', 'NumPyLexer']
+
+
+class PythonLexer(RegexLexer):
+    """
+    For Python source code (version 3.x).
+
+    .. versionchanged:: 2.5
+       This is now the default ``PythonLexer``.  It is still available as the
+       alias ``Python3Lexer``.
+    """
+
+    name = 'Python'
+    url = 'https://www.python.org'
+    aliases = ['python', 'py', 'sage', 'python3', 'py3', 'bazel', 'starlark', 'pyi']
+    filenames = [
+        '*.py',
+        '*.pyw',
+        # Type stubs
+        '*.pyi',
+        # Jython
+        '*.jy',
+        # Sage
+        '*.sage',
+        # SCons
+        '*.sc',
+        'SConstruct',
+        'SConscript',
+        # Skylark/Starlark (used by Bazel, Buck, and Pants)
+        '*.bzl',
+        'BUCK',
+        'BUILD',
+        'BUILD.bazel',
+        'WORKSPACE',
+        # Twisted Application infrastructure
+        '*.tac',
+    ]
+    mimetypes = ['text/x-python', 'application/x-python',
+                 'text/x-python3', 'application/x-python3']
+    version_added = '0.10'
+
+    uni_name = f"[{uni.xid_start}][{uni.xid_continue}]*"
+
+    def innerstring_rules(ttype):
+        return [
+            # the old style '%s' % (...) string formatting (still valid in Py3)
+            (r'%(\(\w+\))?[-#0 +]*([0-9]+|[*])?(\.([0-9]+|[*]))?'
+             '[hlL]?[E-GXc-giorsaux%]', String.Interpol),
+            # the new style '{}'.format(...) string formatting
+            (r'\{'
+             r'((\w+)((\.\w+)|(\[[^\]]+\]))*)?'  # field name
+             r'(\![sra])?'                       # conversion
+             r'(\:(.?[<>=\^])?[-+ ]?#?0?(\d+)?,?(\.\d+)?[E-GXb-gnosx%]?)?'
+             r'\}', String.Interpol),
+
+            # backslashes, quotes and formatting signs must be parsed one at a time
+            (r'[^\\\'"%{\n]+', ttype),
+            (r'[\'"\\]', ttype),
+            # unhandled string formatting sign
+            (r'%|(\{{1,2})', ttype)
+            # newlines are an error (use "nl" state)
+        ]
+
+    def fstring_rules(ttype):
+        return [
+            # Assuming that a '}' is the closing brace after format specifier.
+            # Sadly, this means that we won't detect syntax error. But it's
+            # more important to parse correct syntax correctly, than to
+            # highlight invalid syntax.
+            (r'\}', String.Interpol),
+            (r'\{', String.Interpol, 'expr-inside-fstring'),
+            # backslashes, quotes and formatting signs must be parsed one at a time
+            (r'[^\\\'"{}\n]+', ttype),
+            (r'[\'"\\]', ttype),
+            # newlines are an error (use "nl" state)
+        ]
+
+    tokens = {
+        'root': [
+            (r'\n', Whitespace),
+            (r'^(\s*)([rRuUbB]{,2})("""(?:.|\n)*?""")',
+             bygroups(Whitespace, String.Affix, String.Doc)),
+            (r"^(\s*)([rRuUbB]{,2})('''(?:.|\n)*?''')",
+             bygroups(Whitespace, String.Affix, String.Doc)),
+            (r'\A#!.+$', Comment.Hashbang),
+            (r'#.*$', Comment.Single),
+            (r'\\\n', Text),
+            (r'\\', Text),
+            include('keywords'),
+            include('soft-keywords'),
+            (r'(def)((?:\s|\\\s)+)', bygroups(Keyword, Whitespace), 'funcname'),
+            (r'(class)((?:\s|\\\s)+)', bygroups(Keyword, Whitespace), 'classname'),
+            (r'(from)((?:\s|\\\s)+)', bygroups(Keyword.Namespace, Whitespace),
+             'fromimport'),
+            (r'(import)((?:\s|\\\s)+)', bygroups(Keyword.Namespace, Whitespace),
+             'import'),
+            include('expr'),
+        ],
+        'expr': [
+            # raw f-strings
+            ('(?i)(rf|fr)(""")',
+             bygroups(String.Affix, String.Double),
+             combined('rfstringescape', 'tdqf')),
+            ("(?i)(rf|fr)(''')",
+             bygroups(String.Affix, String.Single),
+             combined('rfstringescape', 'tsqf')),
+            ('(?i)(rf|fr)(")',
+             bygroups(String.Affix, String.Double),
+             combined('rfstringescape', 'dqf')),
+            ("(?i)(rf|fr)(')",
+             bygroups(String.Affix, String.Single),
+             combined('rfstringescape', 'sqf')),
+            # non-raw f-strings
+            ('([fF])(""")', bygroups(String.Affix, String.Double),
+             combined('fstringescape', 'tdqf')),
+            ("([fF])(''')", bygroups(String.Affix, String.Single),
+             combined('fstringescape', 'tsqf')),
+            ('([fF])(")', bygroups(String.Affix, String.Double),
+             combined('fstringescape', 'dqf')),
+            ("([fF])(')", bygroups(String.Affix, String.Single),
+             combined('fstringescape', 'sqf')),
+            # raw bytes and strings
+            ('(?i)(rb|br|r)(""")',
+             bygroups(String.Affix, String.Double), 'tdqs'),
+            ("(?i)(rb|br|r)(''')",
+             bygroups(String.Affix, String.Single), 'tsqs'),
+            ('(?i)(rb|br|r)(")',
+             bygroups(String.Affix, String.Double), 'dqs'),
+            ("(?i)(rb|br|r)(')",
+             bygroups(String.Affix, String.Single), 'sqs'),
+            # non-raw strings
+            ('([uU]?)(""")', bygroups(String.Affix, String.Double),
+             combined('stringescape', 'tdqs')),
+            ("([uU]?)(''')", bygroups(String.Affix, String.Single),
+             combined('stringescape', 'tsqs')),
+            ('([uU]?)(")', bygroups(String.Affix, String.Double),
+             combined('stringescape', 'dqs')),
+            ("([uU]?)(')", bygroups(String.Affix, String.Single),
+             combined('stringescape', 'sqs')),
+            # non-raw bytes
+            ('([bB])(""")', bygroups(String.Affix, String.Double),
+             combined('bytesescape', 'tdqs')),
+            ("([bB])(''')", bygroups(String.Affix, String.Single),
+             combined('bytesescape', 'tsqs')),
+            ('([bB])(")', bygroups(String.Affix, String.Double),
+             combined('bytesescape', 'dqs')),
+            ("([bB])(')", bygroups(String.Affix, String.Single),
+             combined('bytesescape', 'sqs')),
+
+            (r'[^\S\n]+', Text),
+            include('numbers'),
+            (r'!=|==|<<|>>|:=|[-~+/*%=<>&^|.]', Operator),
+            (r'[]{}:(),;[]', Punctuation),
+            (r'(in|is|and|or|not)\b', Operator.Word),
+            include('expr-keywords'),
+            include('builtins'),
+            include('magicfuncs'),
+            include('magicvars'),
+            include('name'),
+        ],
+        'expr-inside-fstring': [
+            (r'[{([]', Punctuation, 'expr-inside-fstring-inner'),
+            # without format specifier
+            (r'(=\s*)?'         # debug (https://bugs.python.org/issue36817)
+             r'(\![sraf])?'     # conversion
+             r'\}', String.Interpol, '#pop'),
+            # with format specifier
+            # we'll catch the remaining '}' in the outer scope
+            (r'(=\s*)?'         # debug (https://bugs.python.org/issue36817)
+             r'(\![sraf])?'     # conversion
+             r':', String.Interpol, '#pop'),
+            (r'\s+', Whitespace),  # allow new lines
+            include('expr'),
+        ],
+        'expr-inside-fstring-inner': [
+            (r'[{([]', Punctuation, 'expr-inside-fstring-inner'),
+            (r'[])}]', Punctuation, '#pop'),
+            (r'\s+', Whitespace),  # allow new lines
+            include('expr'),
+        ],
+        'expr-keywords': [
+            # Based on https://docs.python.org/3/reference/expressions.html
+            (words((
+                'async for', 'await', 'else', 'for', 'if', 'lambda',
+                'yield', 'yield from'), suffix=r'\b'),
+             Keyword),
+            (words(('True', 'False', 'None'), suffix=r'\b'), Keyword.Constant),
+        ],
+        'keywords': [
+            (words((
+                'assert', 'async', 'await', 'break', 'continue', 'del', 'elif',
+                'else', 'except', 'finally', 'for', 'global', 'if', 'lambda',
+                'pass', 'raise', 'nonlocal', 'return', 'try', 'while', 'yield',
+                'yield from', 'as', 'with'), suffix=r'\b'),
+             Keyword),
+            (words(('True', 'False', 'None'), suffix=r'\b'), Keyword.Constant),
+        ],
+        'soft-keywords': [
+            # `match`, `case` and `_` soft keywords
+            (r'(^[ \t]*)'              # at beginning of line + possible indentation
+             r'(match|case)\b'         # a possible keyword
+             r'(?![ \t]*(?:'           # not followed by...
+             r'[:,;=^&|@~)\]}]|(?:' +  # characters and keywords that mean this isn't
+                                       # pattern matching (but None/True/False is ok)
+             r'|'.join(k for k in keyword.kwlist if k[0].islower()) + r')\b))',
+             bygroups(Text, Keyword), 'soft-keywords-inner'),
+        ],
+        'soft-keywords-inner': [
+            # optional `_` keyword
+            (r'(\s+)([^\n_]*)(_\b)', bygroups(Whitespace, using(this), Keyword)),
+            default('#pop')
+        ],
+        'builtins': [
+            (words((
+                '__import__', 'abs', 'aiter', 'all', 'any', 'bin', 'bool', 'bytearray',
+                'breakpoint', 'bytes', 'callable', 'chr', 'classmethod', 'compile',
+                'complex', 'delattr', 'dict', 'dir', 'divmod', 'enumerate', 'eval',
+                'filter', 'float', 'format', 'frozenset', 'getattr', 'globals',
+                'hasattr', 'hash', 'hex', 'id', 'input', 'int', 'isinstance',
+                'issubclass', 'iter', 'len', 'list', 'locals', 'map', 'max',
+                'memoryview', 'min', 'next', 'object', 'oct', 'open', 'ord', 'pow',
+                'print', 'property', 'range', 'repr', 'reversed', 'round', 'set',
+                'setattr', 'slice', 'sorted', 'staticmethod', 'str', 'sum', 'super',
+                'tuple', 'type', 'vars', 'zip'), prefix=r'(?>|[-~+/*%=<>&^|.]', Operator),
+            include('keywords'),
+            (r'(def)((?:\s|\\\s)+)', bygroups(Keyword, Whitespace), 'funcname'),
+            (r'(class)((?:\s|\\\s)+)', bygroups(Keyword, Whitespace), 'classname'),
+            (r'(from)((?:\s|\\\s)+)', bygroups(Keyword.Namespace, Whitespace),
+             'fromimport'),
+            (r'(import)((?:\s|\\\s)+)', bygroups(Keyword.Namespace, Whitespace),
+             'import'),
+            include('builtins'),
+            include('magicfuncs'),
+            include('magicvars'),
+            include('backtick'),
+            ('([rR]|[uUbB][rR]|[rR][uUbB])(""")',
+             bygroups(String.Affix, String.Double), 'tdqs'),
+            ("([rR]|[uUbB][rR]|[rR][uUbB])(''')",
+             bygroups(String.Affix, String.Single), 'tsqs'),
+            ('([rR]|[uUbB][rR]|[rR][uUbB])(")',
+             bygroups(String.Affix, String.Double), 'dqs'),
+            ("([rR]|[uUbB][rR]|[rR][uUbB])(')",
+             bygroups(String.Affix, String.Single), 'sqs'),
+            ('([uUbB]?)(""")', bygroups(String.Affix, String.Double),
+             combined('stringescape', 'tdqs')),
+            ("([uUbB]?)(''')", bygroups(String.Affix, String.Single),
+             combined('stringescape', 'tsqs')),
+            ('([uUbB]?)(")', bygroups(String.Affix, String.Double),
+             combined('stringescape', 'dqs')),
+            ("([uUbB]?)(')", bygroups(String.Affix, String.Single),
+             combined('stringescape', 'sqs')),
+            include('name'),
+            include('numbers'),
+        ],
+        'keywords': [
+            (words((
+                'assert', 'break', 'continue', 'del', 'elif', 'else', 'except',
+                'exec', 'finally', 'for', 'global', 'if', 'lambda', 'pass',
+                'print', 'raise', 'return', 'try', 'while', 'yield',
+                'yield from', 'as', 'with'), suffix=r'\b'),
+             Keyword),
+        ],
+        'builtins': [
+            (words((
+                '__import__', 'abs', 'all', 'any', 'apply', 'basestring', 'bin',
+                'bool', 'buffer', 'bytearray', 'bytes', 'callable', 'chr', 'classmethod',
+                'cmp', 'coerce', 'compile', 'complex', 'delattr', 'dict', 'dir', 'divmod',
+                'enumerate', 'eval', 'execfile', 'exit', 'file', 'filter', 'float',
+                'frozenset', 'getattr', 'globals', 'hasattr', 'hash', 'hex', 'id',
+                'input', 'int', 'intern', 'isinstance', 'issubclass', 'iter', 'len',
+                'list', 'locals', 'long', 'map', 'max', 'min', 'next', 'object',
+                'oct', 'open', 'ord', 'pow', 'property', 'range', 'raw_input', 'reduce',
+                'reload', 'repr', 'reversed', 'round', 'set', 'setattr', 'slice',
+                'sorted', 'staticmethod', 'str', 'sum', 'super', 'tuple', 'type',
+                'unichr', 'unicode', 'vars', 'xrange', 'zip'),
+                prefix=r'(?>> )(.*\n)', bygroups(Generic.Prompt, Other.Code), 'continuations'),
+            # This happens, e.g., when tracebacks are embedded in documentation;
+            # trailing whitespaces are often stripped in such contexts.
+            (r'(>>>)(\n)', bygroups(Generic.Prompt, Whitespace)),
+            (r'(\^C)?Traceback \(most recent call last\):\n', Other.Traceback, 'traceback'),
+            # SyntaxError starts with this
+            (r'  File "[^"]+", line \d+', Other.Traceback, 'traceback'),
+            (r'.*\n', Generic.Output),
+        ],
+        'continuations': [
+            (r'(\.\.\. )(.*\n)', bygroups(Generic.Prompt, Other.Code)),
+            # See above.
+            (r'(\.\.\.)(\n)', bygroups(Generic.Prompt, Whitespace)),
+            default('#pop'),
+        ],
+        'traceback': [
+            # As soon as we see a traceback, consume everything until the next
+            # >>> prompt.
+            (r'(?=>>>( |$))', Text, '#pop'),
+            (r'(KeyboardInterrupt)(\n)', bygroups(Name.Class, Whitespace)),
+            (r'.*\n', Other.Traceback),
+        ],
+    }
+
+
+class PythonConsoleLexer(DelegatingLexer):
+    """
+    For Python console output or doctests, such as:
+
+    .. sourcecode:: pycon
+
+        >>> a = 'foo'
+        >>> print(a)
+        foo
+        >>> 1 / 0
+        Traceback (most recent call last):
+          File "", line 1, in 
+        ZeroDivisionError: integer division or modulo by zero
+
+    Additional options:
+
+    `python3`
+        Use Python 3 lexer for code.  Default is ``True``.
+
+        .. versionadded:: 1.0
+        .. versionchanged:: 2.5
+           Now defaults to ``True``.
+    """
+
+    name = 'Python console session'
+    aliases = ['pycon', 'python-console']
+    mimetypes = ['text/x-python-doctest']
+    url = 'https://python.org'
+    version_added = ''
+
+    def __init__(self, **options):
+        python3 = get_bool_opt(options, 'python3', True)
+        if python3:
+            pylexer = PythonLexer
+            tblexer = PythonTracebackLexer
+        else:
+            pylexer = Python2Lexer
+            tblexer = Python2TracebackLexer
+        # We have two auxiliary lexers. Use DelegatingLexer twice with
+        # different tokens.  TODO: DelegatingLexer should support this
+        # directly, by accepting a tuplet of auxiliary lexers and a tuple of
+        # distinguishing tokens. Then we wouldn't need this intermediary
+        # class.
+        class _ReplaceInnerCode(DelegatingLexer):
+            def __init__(self, **options):
+                super().__init__(pylexer, _PythonConsoleLexerBase, Other.Code, **options)
+        super().__init__(tblexer, _ReplaceInnerCode, Other.Traceback, **options)
+
+
+class PythonTracebackLexer(RegexLexer):
+    """
+    For Python 3.x tracebacks, with support for chained exceptions.
+
+    .. versionchanged:: 2.5
+       This is now the default ``PythonTracebackLexer``.  It is still available
+       as the alias ``Python3TracebackLexer``.
+    """
+
+    name = 'Python Traceback'
+    aliases = ['pytb', 'py3tb']
+    filenames = ['*.pytb', '*.py3tb']
+    mimetypes = ['text/x-python-traceback', 'text/x-python3-traceback']
+    url = 'https://python.org'
+    version_added = '1.0'
+
+    tokens = {
+        'root': [
+            (r'\n', Whitespace),
+            (r'^(\^C)?Traceback \(most recent call last\):\n', Generic.Traceback, 'intb'),
+            (r'^During handling of the above exception, another '
+             r'exception occurred:\n\n', Generic.Traceback),
+            (r'^The above exception was the direct cause of the '
+             r'following exception:\n\n', Generic.Traceback),
+            (r'^(?=  File "[^"]+", line \d+)', Generic.Traceback, 'intb'),
+            (r'^.*\n', Other),
+        ],
+        'intb': [
+            (r'^(  File )("[^"]+")(, line )(\d+)(, in )(.+)(\n)',
+             bygroups(Text, Name.Builtin, Text, Number, Text, Name, Whitespace)),
+            (r'^(  File )("[^"]+")(, line )(\d+)(\n)',
+             bygroups(Text, Name.Builtin, Text, Number, Whitespace)),
+            (r'^(    )(.+)(\n)',
+             bygroups(Whitespace, using(PythonLexer), Whitespace), 'markers'),
+            (r'^([ \t]*)(\.\.\.)(\n)',
+             bygroups(Whitespace, Comment, Whitespace)),  # for doctests...
+            (r'^([^:]+)(: )(.+)(\n)',
+             bygroups(Generic.Error, Text, Name, Whitespace), '#pop'),
+            (r'^([a-zA-Z_][\w.]*)(:?\n)',
+             bygroups(Generic.Error, Whitespace), '#pop'),
+            default('#pop'),
+        ],
+        'markers': [
+            # Either `PEP 657 `
+            # error locations in Python 3.11+, or single-caret markers
+            # for syntax errors before that.
+            (r'^( {4,})([~^]+)(\n)',
+             bygroups(Whitespace, Punctuation.Marker, Whitespace),
+             '#pop'),
+            default('#pop'),
+        ],
+    }
+
+
+Python3TracebackLexer = PythonTracebackLexer
+
+
+class Python2TracebackLexer(RegexLexer):
+    """
+    For Python tracebacks.
+
+    .. versionchanged:: 2.5
+       This class has been renamed from ``PythonTracebackLexer``.
+       ``PythonTracebackLexer`` now refers to the Python 3 variant.
+    """
+
+    name = 'Python 2.x Traceback'
+    aliases = ['py2tb']
+    filenames = ['*.py2tb']
+    mimetypes = ['text/x-python2-traceback']
+    url = 'https://python.org'
+    version_added = '0.7'
+
+    tokens = {
+        'root': [
+            # Cover both (most recent call last) and (innermost last)
+            # The optional ^C allows us to catch keyboard interrupt signals.
+            (r'^(\^C)?(Traceback.*\n)',
+             bygroups(Text, Generic.Traceback), 'intb'),
+            # SyntaxError starts with this.
+            (r'^(?=  File "[^"]+", line \d+)', Generic.Traceback, 'intb'),
+            (r'^.*\n', Other),
+        ],
+        'intb': [
+            (r'^(  File )("[^"]+")(, line )(\d+)(, in )(.+)(\n)',
+             bygroups(Text, Name.Builtin, Text, Number, Text, Name, Whitespace)),
+            (r'^(  File )("[^"]+")(, line )(\d+)(\n)',
+             bygroups(Text, Name.Builtin, Text, Number, Whitespace)),
+            (r'^(    )(.+)(\n)',
+             bygroups(Text, using(Python2Lexer), Whitespace), 'marker'),
+            (r'^([ \t]*)(\.\.\.)(\n)',
+             bygroups(Text, Comment, Whitespace)),  # for doctests...
+            (r'^([^:]+)(: )(.+)(\n)',
+             bygroups(Generic.Error, Text, Name, Whitespace), '#pop'),
+            (r'^([a-zA-Z_]\w*)(:?\n)',
+             bygroups(Generic.Error, Whitespace), '#pop')
+        ],
+        'marker': [
+            # For syntax errors.
+            (r'( {4,})(\^)', bygroups(Text, Punctuation.Marker), '#pop'),
+            default('#pop'),
+        ],
+    }
+
+
+class CythonLexer(RegexLexer):
+    """
+    For Pyrex and Cython source code.
+    """
+
+    name = 'Cython'
+    url = 'https://cython.org'
+    aliases = ['cython', 'pyx', 'pyrex']
+    filenames = ['*.pyx', '*.pxd', '*.pxi']
+    mimetypes = ['text/x-cython', 'application/x-cython']
+    version_added = '1.1'
+
+    tokens = {
+        'root': [
+            (r'\n', Whitespace),
+            (r'^(\s*)("""(?:.|\n)*?""")', bygroups(Whitespace, String.Doc)),
+            (r"^(\s*)('''(?:.|\n)*?''')", bygroups(Whitespace, String.Doc)),
+            (r'[^\S\n]+', Text),
+            (r'#.*$', Comment),
+            (r'[]{}:(),;[]', Punctuation),
+            (r'\\\n', Whitespace),
+            (r'\\', Text),
+            (r'(in|is|and|or|not)\b', Operator.Word),
+            (r'(<)([a-zA-Z0-9.?]+)(>)',
+             bygroups(Punctuation, Keyword.Type, Punctuation)),
+            (r'!=|==|<<|>>|[-~+/*%=<>&^|.?]', Operator),
+            (r'(from)(\d+)(<=)(\s+)(<)(\d+)(:)',
+             bygroups(Keyword, Number.Integer, Operator, Whitespace, Operator,
+                      Name, Punctuation)),
+            include('keywords'),
+            (r'(def|property)(\s+)', bygroups(Keyword, Whitespace), 'funcname'),
+            (r'(cp?def)(\s+)', bygroups(Keyword, Whitespace), 'cdef'),
+            # (should actually start a block with only cdefs)
+            (r'(cdef)(:)', bygroups(Keyword, Punctuation)),
+            (r'(class|struct)(\s+)', bygroups(Keyword, Whitespace), 'classname'),
+            (r'(from)(\s+)', bygroups(Keyword, Whitespace), 'fromimport'),
+            (r'(c?import)(\s+)', bygroups(Keyword, Whitespace), 'import'),
+            include('builtins'),
+            include('backtick'),
+            ('(?:[rR]|[uU][rR]|[rR][uU])"""', String, 'tdqs'),
+            ("(?:[rR]|[uU][rR]|[rR][uU])'''", String, 'tsqs'),
+            ('(?:[rR]|[uU][rR]|[rR][uU])"', String, 'dqs'),
+            ("(?:[rR]|[uU][rR]|[rR][uU])'", String, 'sqs'),
+            ('[uU]?"""', String, combined('stringescape', 'tdqs')),
+            ("[uU]?'''", String, combined('stringescape', 'tsqs')),
+            ('[uU]?"', String, combined('stringescape', 'dqs')),
+            ("[uU]?'", String, combined('stringescape', 'sqs')),
+            include('name'),
+            include('numbers'),
+        ],
+        'keywords': [
+            (words((
+                'assert', 'async', 'await', 'break', 'by', 'continue', 'ctypedef', 'del', 'elif',
+                'else', 'except', 'except?', 'exec', 'finally', 'for', 'fused', 'gil',
+                'global', 'if', 'include', 'lambda', 'nogil', 'pass', 'print',
+                'raise', 'return', 'try', 'while', 'yield', 'as', 'with'), suffix=r'\b'),
+             Keyword),
+            (r'(DEF|IF|ELIF|ELSE)\b', Comment.Preproc),
+        ],
+        'builtins': [
+            (words((
+                '__import__', 'abs', 'all', 'any', 'apply', 'basestring', 'bin', 'bint',
+                'bool', 'buffer', 'bytearray', 'bytes', 'callable', 'chr',
+                'classmethod', 'cmp', 'coerce', 'compile', 'complex', 'delattr',
+                'dict', 'dir', 'divmod', 'enumerate', 'eval', 'execfile', 'exit',
+                'file', 'filter', 'float', 'frozenset', 'getattr', 'globals',
+                'hasattr', 'hash', 'hex', 'id', 'input', 'int', 'intern', 'isinstance',
+                'issubclass', 'iter', 'len', 'list', 'locals', 'long', 'map', 'max',
+                'min', 'next', 'object', 'oct', 'open', 'ord', 'pow', 'property', 'Py_ssize_t',
+                'range', 'raw_input', 'reduce', 'reload', 'repr', 'reversed',
+                'round', 'set', 'setattr', 'slice', 'sorted', 'staticmethod',
+                'str', 'sum', 'super', 'tuple', 'type', 'unichr', 'unicode', 'unsigned',
+                'vars', 'xrange', 'zip'), prefix=r'(??/\\:']?:)(\s*)(\{)",
+             bygroups(Name.Function, Whitespace, Operator, Whitespace, Punctuation),
+             "functions"),
+            # Variable Names
+            (r"([.]?[a-zA-Z][\w.]*)(\s*)([-.~=!@#$%^&*_+|,<>?/\\:']?:)",
+             bygroups(Name.Variable, Whitespace, Operator)),
+            # Functions
+            (r"\{", Punctuation, "functions"),
+            # Parentheses
+            (r"\(", Punctuation, "parentheses"),
+            # Brackets
+            (r"\[", Punctuation, "brackets"),
+            # Errors
+            (r"'`([a-zA-Z][\w.]*)?", Name.Exception),
+            # File Symbols
+            (r"`:([a-zA-Z/][\w./]*)?", String.Symbol),
+            # Symbols
+            (r"`([a-zA-Z][\w.]*)?", String.Symbol),
+            # Numbers
+            include("numbers"),
+            # Variable Names
+            (r"[a-zA-Z][\w.]*", Name),
+            # Operators
+            (r"[-=+*#$%@!~^&:.,<>'\\|/?_]", Operator),
+            # Punctuation
+            (r";", Punctuation),
+        ],
+        "functions": [
+            include("root"),
+            (r"\}", Punctuation, "#pop"),
+        ],
+        "parentheses": [
+            include("root"),
+            (r"\)", Punctuation, "#pop"),
+        ],
+        "brackets": [
+            include("root"),
+            (r"\]", Punctuation, "#pop"),
+        ],
+        "numbers": [
+            # Binary Values
+            (r"[01]+b", Number.Bin),
+            # Nulls/Infinities
+            (r"0[nNwW][cefghijmndzuvtp]?", Number),
+            # Timestamps
+            ((r"(?:[0-9]{4}[.][0-9]{2}[.][0-9]{2}|[0-9]+)"
+              "D(?:[0-9](?:[0-9](?::[0-9]{2}"
+              "(?::[0-9]{2}(?:[.][0-9]*)?)?)?)?)?"), Literal.Date),
+            # Datetimes
+            ((r"[0-9]{4}[.][0-9]{2}"
+              "(?:m|[.][0-9]{2}(?:T(?:[0-9]{2}:[0-9]{2}"
+              "(?::[0-9]{2}(?:[.][0-9]*)?)?)?)?)"), Literal.Date),
+            # Times
+            (r"[0-9]{2}:[0-9]{2}(?::[0-9]{2}(?:[.][0-9]{1,3})?)?",
+             Literal.Date),
+            # GUIDs
+            (r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}",
+             Number.Hex),
+            # Byte Vectors
+            (r"0x[0-9a-fA-F]+", Number.Hex),
+            # Floats
+            (r"([0-9]*[.]?[0-9]+|[0-9]+[.]?[0-9]*)[eE][+-]?[0-9]+[ef]?",
+             Number.Float),
+            (r"([0-9]*[.][0-9]+|[0-9]+[.][0-9]*)[ef]?", Number.Float),
+            (r"[0-9]+[ef]", Number.Float),
+            # Characters
+            (r"[0-9]+c", Number),
+            # Integers
+            (r"[0-9]+[ihtuv]", Number.Integer),
+            # Long Integers
+            (r"[0-9]+[jnp]?", Number.Integer.Long),
+        ],
+        "comments": [
+            (r"[^\\]+", Comment.Multiline),
+            (r"^\\", Comment.Multiline, "#pop"),
+            (r"\\", Comment.Multiline),
+        ],
+        "strings": [
+            (r'[^"\\]+', String.Double),
+            (r"\\.", String.Escape),
+            (r'"', String.Double, "#pop"),
+        ],
+    }
+
+
+class QLexer(KLexer):
+    """
+    For `Q `_ source code.
+    """
+
+    name = "Q"
+    aliases = ["q"]
+    filenames = ["*.q"]
+    version_added = '2.12'
+
+    tokens = {
+        "root": [
+            (words(("aj", "aj0", "ajf", "ajf0", "all", "and", "any", "asc",
+                    "asof", "attr", "avgs", "ceiling", "cols", "count", "cross",
+                    "csv", "cut", "deltas", "desc", "differ", "distinct", "dsave",
+                    "each", "ej", "ema", "eval", "except", "fby", "fills", "first",
+                    "fkeys", "flip", "floor", "get", "group", "gtime", "hclose",
+                    "hcount", "hdel", "hsym", "iasc", "idesc", "ij", "ijf",
+                    "inter", "inv", "key", "keys", "lj", "ljf", "load", "lower",
+                    "lsq", "ltime", "ltrim", "mavg", "maxs", "mcount", "md5",
+                    "mdev", "med", "meta", "mins", "mmax", "mmin", "mmu", "mod",
+                    "msum", "neg", "next", "not", "null", "or", "over", "parse",
+                    "peach", "pj", "prds", "prior", "prev", "rand", "rank", "ratios",
+                    "raze", "read0", "read1", "reciprocal", "reval", "reverse",
+                    "rload", "rotate", "rsave", "rtrim", "save", "scan", "scov",
+                    "sdev", "set", "show", "signum", "ssr", "string", "sublist",
+                    "sums", "sv", "svar", "system", "tables", "til", "trim", "txf",
+                    "type", "uj", "ujf", "ungroup", "union", "upper", "upsert",
+                    "value", "view", "views", "vs", "where", "wj", "wj1", "ww",
+                    "xasc", "xbar", "xcol", "xcols", "xdesc", "xgroup", "xkey",
+                    "xlog", "xprev", "xrank"),
+                    suffix=r"\b"), Name.Builtin,
+            ),
+            inherit,
+        ],
+    }
diff --git a/pygments/lexers/qlik.py b/pygments/lexers/qlik.py
new file mode 100644
index 0000000000000000000000000000000000000000..a29f89f35a9b1f4e8be514da1259e42feb2a9642
--- /dev/null
+++ b/pygments/lexers/qlik.py
@@ -0,0 +1,117 @@
+"""
+    pygments.lexers.qlik
+    ~~~~~~~~~~~~~~~~~~~~
+
+    Lexer for the qlik scripting language
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import re
+
+from pygments.lexer import RegexLexer, include, bygroups, words
+from pygments.token import Comment, Keyword, Name, Number, Operator, \
+    Punctuation, String, Text
+from pygments.lexers._qlik_builtins import OPERATORS_LIST, STATEMENT_LIST, \
+    SCRIPT_FUNCTIONS, CONSTANT_LIST
+
+__all__ = ["QlikLexer"]
+
+
+class QlikLexer(RegexLexer):
+    """
+    Lexer for qlik code, including .qvs files
+    """
+
+    name = "Qlik"
+    aliases = ["qlik", "qlikview", "qliksense", "qlikscript"]
+    filenames = ["*.qvs", "*.qvw"]
+    url = "https://qlik.com"
+    version_added = '2.12'
+
+    flags = re.IGNORECASE
+
+    tokens = {
+        # Handle multi-line comments
+        "comment": [
+            (r"\*/", Comment.Multiline, "#pop"),
+            (r"[^*]+", Comment.Multiline),
+        ],
+        # Handle numbers
+        "numerics": [
+            (r"\b\d+\.\d+(e\d+)?[fd]?\b", Number.Float),
+            (r"\b\d+\b", Number.Integer),
+        ],
+        # Handle variable names in things
+        "interp": [
+            (
+                r"(\$\()(\w+)(\))",
+                bygroups(String.Interpol, Name.Variable, String.Interpol),
+            ),
+        ],
+        # Handle strings
+        "string": [
+            (r"'", String, "#pop"),
+            include("interp"),
+            (r"[^'$]+", String),
+            (r"\$", String),
+        ],
+        #
+        "assignment": [
+            (r";", Punctuation, "#pop"),
+            include("root"),
+        ],
+        "field_name_quote": [
+            (r'"', String.Symbol, "#pop"),
+            include("interp"),
+            (r"[^\"$]+", String.Symbol),
+            (r"\$", String.Symbol),
+        ],
+        "field_name_bracket": [
+            (r"\]", String.Symbol, "#pop"),
+            include("interp"),
+            (r"[^\]$]+", String.Symbol),
+            (r"\$", String.Symbol),
+        ],
+        "function": [(r"\)", Punctuation, "#pop"), include("root")],
+        "root": [
+            # Whitespace and comments
+            (r"\s+", Text.Whitespace),
+            (r"/\*", Comment.Multiline, "comment"),
+            (r"//.*\n", Comment.Single),
+            # variable assignment
+            (r"(let|set)(\s+)", bygroups(Keyword.Declaration, Text.Whitespace),
+             "assignment"),
+            # Word operators
+            (words(OPERATORS_LIST["words"], prefix=r"\b", suffix=r"\b"),
+             Operator.Word),
+            # Statements
+            (words(STATEMENT_LIST, suffix=r"\b"), Keyword),
+            # Table names
+            (r"[a-z]\w*:", Keyword.Declaration),
+            # Constants
+            (words(CONSTANT_LIST, suffix=r"\b"), Keyword.Constant),
+            # Functions
+            (words(SCRIPT_FUNCTIONS, suffix=r"(?=\s*\()"), Name.Builtin,
+             "function"),
+            # interpolation - e.g. $(variableName)
+            include("interp"),
+            # Quotes denote a field/file name
+            (r'"', String.Symbol, "field_name_quote"),
+            # Square brackets denote a field/file name
+            (r"\[", String.Symbol, "field_name_bracket"),
+            # Strings
+            (r"'", String, "string"),
+            # Numbers
+            include("numerics"),
+            # Operator symbols
+            (words(OPERATORS_LIST["symbols"]), Operator),
+            # Strings denoted by single quotes
+            (r"'.+?'", String),
+            # Words as text
+            (r"\b\w+\b", Text),
+            # Basic punctuation
+            (r"[,;.()\\/]", Punctuation),
+        ],
+    }
diff --git a/pygments/lexers/qvt.py b/pygments/lexers/qvt.py
new file mode 100644
index 0000000000000000000000000000000000000000..302d1b6ed894ef6a928fd09ba7b679791d4ddaa7
--- /dev/null
+++ b/pygments/lexers/qvt.py
@@ -0,0 +1,153 @@
+"""
+    pygments.lexers.qvt
+    ~~~~~~~~~~~~~~~~~~~
+
+    Lexer for QVT Operational language.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, bygroups, include, combined, default, \
+    words
+from pygments.token import Text, Comment, Operator, Keyword, Punctuation, \
+    Name, String, Number
+
+__all__ = ['QVToLexer']
+
+
+class QVToLexer(RegexLexer):
+    """
+    For the QVT Operational Mapping language.
+
+    Reference for implementing this: «Meta Object Facility (MOF) 2.0
+    Query/View/Transformation Specification», Version 1.1 - January 2011
+    (https://www.omg.org/spec/QVT/1.1/), see §8.4, «Concrete Syntax» in
+    particular.
+
+    Notable tokens assignments:
+
+    - Name.Class is assigned to the identifier following any of the following
+      keywords: metamodel, class, exception, primitive, enum, transformation
+      or library
+
+    - Name.Function is assigned to the names of mappings and queries
+
+    - Name.Builtin.Pseudo is assigned to the pre-defined variables 'this',
+      'self' and 'result'.
+    """
+    # With obvious borrowings & inspiration from the Java, Python and C lexers
+
+    name = 'QVTO'
+    aliases = ['qvto', 'qvt']
+    filenames = ['*.qvto']
+    url = 'https://www.omg.org/spec/QVT/1.1'
+    version_added = ''
+
+    tokens = {
+        'root': [
+            (r'\n', Text),
+            (r'[^\S\n]+', Text),
+            (r'(--|//)(\s*)(directive:)?(.*)$',
+             bygroups(Comment, Comment, Comment.Preproc, Comment)),
+            # Uncomment the following if you want to distinguish between
+            # '/*' and '/**', à la javadoc
+            # (r'/[*]{2}(.|\n)*?[*]/', Comment.Multiline),
+            (r'/[*](.|\n)*?[*]/', Comment.Multiline),
+            (r'\\\n', Text),
+            (r'(and|not|or|xor|##?)\b', Operator.Word),
+            (r'(:{1,2}=|[-+]=)\b', Operator.Word),
+            (r'(@|<<|>>)\b', Keyword),  # stereotypes
+            (r'!=|<>|==|=|!->|->|>=|<=|[.]{3}|[+/*%=<>&|.~]', Operator),
+            (r'[]{}:(),;[]', Punctuation),
+            (r'(true|false|unlimited|null)\b', Keyword.Constant),
+            (r'(this|self|result)\b', Name.Builtin.Pseudo),
+            (r'(var)\b', Keyword.Declaration),
+            (r'(from|import)\b', Keyword.Namespace, 'fromimport'),
+            (r'(metamodel|class|exception|primitive|enum|transformation|'
+             r'library)(\s+)(\w+)',
+             bygroups(Keyword.Word, Text, Name.Class)),
+            (r'(exception)(\s+)(\w+)',
+             bygroups(Keyword.Word, Text, Name.Exception)),
+            (r'(main)\b', Name.Function),
+            (r'(mapping|helper|query)(\s+)',
+             bygroups(Keyword.Declaration, Text), 'operation'),
+            (r'(assert)(\s+)\b', bygroups(Keyword, Text), 'assert'),
+            (r'(Bag|Collection|Dict|OrderedSet|Sequence|Set|Tuple|List)\b',
+             Keyword.Type),
+            include('keywords'),
+            ('"', String, combined('stringescape', 'dqs')),
+            ("'", String, combined('stringescape', 'sqs')),
+            include('name'),
+            include('numbers'),
+            # (r'([a-zA-Z_]\w*)(::)([a-zA-Z_]\w*)',
+            # bygroups(Text, Text, Text)),
+        ],
+
+        'fromimport': [
+            (r'(?:[ \t]|\\\n)+', Text),
+            (r'[a-zA-Z_][\w.]*', Name.Namespace),
+            default('#pop'),
+        ],
+
+        'operation': [
+            (r'::', Text),
+            (r'(.*::)([a-zA-Z_]\w*)([ \t]*)(\()',
+             bygroups(Text, Name.Function, Text, Punctuation), '#pop')
+        ],
+
+        'assert': [
+            (r'(warning|error|fatal)\b', Keyword, '#pop'),
+            default('#pop'),  # all else: go back
+        ],
+
+        'keywords': [
+            (words((
+                'abstract', 'access', 'any', 'assert', 'blackbox', 'break',
+                'case', 'collect', 'collectNested', 'collectOne', 'collectselect',
+                'collectselectOne', 'composes', 'compute', 'configuration',
+                'constructor', 'continue', 'datatype', 'default', 'derived',
+                'disjuncts', 'do', 'elif', 'else', 'end', 'endif', 'except',
+                'exists', 'extends', 'forAll', 'forEach', 'forOne', 'from', 'if',
+                'implies', 'in', 'inherits', 'init', 'inout', 'intermediate',
+                'invresolve', 'invresolveIn', 'invresolveone', 'invresolveoneIn',
+                'isUnique', 'iterate', 'late', 'let', 'literal', 'log', 'map',
+                'merges', 'modeltype', 'new', 'object', 'one', 'ordered', 'out',
+                'package', 'population', 'property', 'raise', 'readonly',
+                'references', 'refines', 'reject', 'resolve', 'resolveIn',
+                'resolveone', 'resolveoneIn', 'return', 'select', 'selectOne',
+                'sortedBy', 'static', 'switch', 'tag', 'then', 'try', 'typedef',
+                'unlimited', 'uses', 'when', 'where', 'while', 'with', 'xcollect',
+                'xmap', 'xselect'), suffix=r'\b'), Keyword),
+        ],
+
+        # There is no need to distinguish between String.Single and
+        # String.Double: 'strings' is factorised for 'dqs' and 'sqs'
+        'strings': [
+            (r'[^\\\'"\n]+', String),
+            # quotes, percents and backslashes must be parsed one at a time
+            (r'[\'"\\]', String),
+        ],
+        'stringescape': [
+            (r'\\([\\btnfr"\']|u[0-3][0-7]{2}|u[0-7]{1,2})', String.Escape)
+        ],
+        'dqs': [  # double-quoted string
+            (r'"', String, '#pop'),
+            (r'\\\\|\\"', String.Escape),
+            include('strings')
+        ],
+        'sqs': [  # single-quoted string
+            (r"'", String, '#pop'),
+            (r"\\\\|\\'", String.Escape),
+            include('strings')
+        ],
+        'name': [
+            (r'[a-zA-Z_]\w*', Name),
+        ],
+        # numbers: excerpt taken from the python lexer
+        'numbers': [
+            (r'(\d+\.\d*|\d*\.\d+)([eE][+-]?[0-9]+)?', Number.Float),
+            (r'\d+[eE][+-]?[0-9]+', Number.Float),
+            (r'\d+', Number.Integer)
+        ],
+    }
diff --git a/pygments/lexers/r.py b/pygments/lexers/r.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3f65ba2d8644a6faa4f50526d2e31b08ed4d339
--- /dev/null
+++ b/pygments/lexers/r.py
@@ -0,0 +1,196 @@
+"""
+    pygments.lexers.r
+    ~~~~~~~~~~~~~~~~~
+
+    Lexers for the R/S languages.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import re
+
+from pygments.lexer import Lexer, RegexLexer, include, do_insertions
+from pygments.token import Text, Comment, Operator, Keyword, Name, String, \
+    Number, Punctuation, Generic, Whitespace
+
+__all__ = ['RConsoleLexer', 'SLexer', 'RdLexer']
+
+
+line_re  = re.compile('.*?\n')
+
+
+class RConsoleLexer(Lexer):
+    """
+    For R console transcripts or R CMD BATCH output files.
+    """
+
+    name = 'RConsole'
+    aliases = ['rconsole', 'rout']
+    filenames = ['*.Rout']
+    url = 'https://www.r-project.org'
+    version_added = ''
+    _example = "rconsole/r-console-transcript.Rout"
+
+    def get_tokens_unprocessed(self, text):
+        slexer = SLexer(**self.options)
+
+        current_code_block = ''
+        insertions = []
+
+        for match in line_re.finditer(text):
+            line = match.group()
+            if line.startswith('>') or line.startswith('+'):
+                # Colorize the prompt as such,
+                # then put rest of line into current_code_block
+                insertions.append((len(current_code_block),
+                                   [(0, Generic.Prompt, line[:2])]))
+                current_code_block += line[2:]
+            else:
+                # We have reached a non-prompt line!
+                # If we have stored prompt lines, need to process them first.
+                if current_code_block:
+                    # Weave together the prompts and highlight code.
+                    yield from do_insertions(
+                        insertions, slexer.get_tokens_unprocessed(current_code_block))
+                    # Reset vars for next code block.
+                    current_code_block = ''
+                    insertions = []
+                # Now process the actual line itself, this is output from R.
+                yield match.start(), Generic.Output, line
+
+        # If we happen to end on a code block with nothing after it, need to
+        # process the last code block. This is neither elegant nor DRY so
+        # should be changed.
+        if current_code_block:
+            yield from do_insertions(
+                insertions, slexer.get_tokens_unprocessed(current_code_block))
+
+
+class SLexer(RegexLexer):
+    """
+    For S, S-plus, and R source code.
+    """
+
+    name = 'S'
+    aliases = ['splus', 's', 'r']
+    filenames = ['*.S', '*.R', '.Rhistory', '.Rprofile', '.Renviron']
+    mimetypes = ['text/S-plus', 'text/S', 'text/x-r-source', 'text/x-r',
+                 'text/x-R', 'text/x-r-history', 'text/x-r-profile']
+    url = 'https://www.r-project.org'
+    version_added = '0.10'
+
+    valid_name = r'`[^`\\]*(?:\\.[^`\\]*)*`|(?:[a-zA-Z]|\.[A-Za-z_.])[\w.]*|\.'
+    tokens = {
+        'comments': [
+            (r'#.*$', Comment.Single),
+        ],
+        'valid_name': [
+            (valid_name, Name),
+        ],
+        'function_name': [
+            (rf'({valid_name})\s*(?=\()', Name.Function),
+        ],
+        'punctuation': [
+            (r'\[{1,2}|\]{1,2}|\(|\)|;|,', Punctuation),
+        ],
+        'keywords': [
+            (r'(if|else|for|while|repeat|in|next|break|return|switch|function)'
+             r'(?![\w.])',
+             Keyword.Reserved),
+        ],
+        'operators': [
+            (r'<>?|-|==|<=|>=|\|>|<|>|&&?|!=|\|\|?|\?', Operator),
+            (r'\*|\+|\^|/|!|%[^%]*%|=|~|\$|@|:{1,3}', Operator),
+        ],
+        'builtin_symbols': [
+            (r'(NULL|NA(_(integer|real|complex|character)_)?|'
+             r'letters|LETTERS|Inf|TRUE|FALSE|NaN|pi|\.\.(\.|[0-9]+))'
+             r'(?![\w.])',
+             Keyword.Constant),
+            (r'(T|F)\b', Name.Builtin.Pseudo),
+        ],
+        'numbers': [
+            # hex number
+            (r'0[xX][a-fA-F0-9]+([pP][0-9]+)?[Li]?', Number.Hex),
+            # decimal number
+            (r'[+-]?([0-9]+(\.[0-9]+)?|\.[0-9]+|\.)([eE][+-]?[0-9]+)?[Li]?',
+             Number),
+        ],
+        'statements': [
+            include('comments'),
+            # whitespaces
+            (r'\s+', Whitespace),
+            (r'\'', String, 'string_squote'),
+            (r'\"', String, 'string_dquote'),
+            include('builtin_symbols'),
+            include('keywords'),
+            include('function_name'),
+            include('valid_name'),
+            include('numbers'),
+            include('punctuation'),
+            include('operators'),
+        ],
+        'root': [
+            # calls:
+            include('statements'),
+            # blocks:
+            (r'\{|\}', Punctuation),
+            # (r'\{', Punctuation, 'block'),
+            (r'.', Text),
+        ],
+        # 'block': [
+        #    include('statements'),
+        #    ('\{', Punctuation, '#push'),
+        #    ('\}', Punctuation, '#pop')
+        # ],
+        'string_squote': [
+            (r'([^\'\\]|\\.)*\'', String, '#pop'),
+        ],
+        'string_dquote': [
+            (r'([^"\\]|\\.)*"', String, '#pop'),
+        ],
+    }
+
+    def analyse_text(text):
+        if re.search(r'[a-z0-9_\])\s]<-(?!-)', text):
+            return 0.11
+
+
+class RdLexer(RegexLexer):
+    """
+    Pygments Lexer for R documentation (Rd) files
+
+    This is a very minimal implementation, highlighting little more
+    than the macros. A description of Rd syntax is found in `Writing R
+    Extensions `_
+    and `Parsing Rd files `_.
+    """
+    name = 'Rd'
+    aliases = ['rd']
+    filenames = ['*.Rd']
+    mimetypes = ['text/x-r-doc']
+    url = 'http://cran.r-project.org/doc/manuals/R-exts.html'
+    version_added = '1.6'
+
+    # To account for verbatim / LaTeX-like / and R-like areas
+    # would require parsing.
+    tokens = {
+        'root': [
+            # catch escaped brackets and percent sign
+            (r'\\[\\{}%]', String.Escape),
+            # comments
+            (r'%.*$', Comment),
+            # special macros with no arguments
+            (r'\\(?:cr|l?dots|R|tab)\b', Keyword.Constant),
+            # macros
+            (r'\\[a-zA-Z]+\b', Keyword),
+            # special preprocessor macros
+            (r'^\s*#(?:ifn?def|endif).*\b', Comment.Preproc),
+            # non-escaped brackets
+            (r'[{}]', Name.Builtin),
+            # everything else
+            (r'[^\\%\n{}]+', Text),
+            (r'.', Text),
+        ]
+    }
diff --git a/pygments/lexers/rdf.py b/pygments/lexers/rdf.py
new file mode 100644
index 0000000000000000000000000000000000000000..4930c1b387a6f098074fc15e26a1f7a62f3554c1
--- /dev/null
+++ b/pygments/lexers/rdf.py
@@ -0,0 +1,468 @@
+"""
+    pygments.lexers.rdf
+    ~~~~~~~~~~~~~~~~~~~
+
+    Lexers for semantic web and RDF query languages and markup.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import re
+
+from pygments.lexer import RegexLexer, bygroups, default
+from pygments.token import Keyword, Punctuation, String, Number, Operator, \
+    Generic, Whitespace, Name, Literal, Comment, Text
+
+__all__ = ['SparqlLexer', 'TurtleLexer', 'ShExCLexer']
+
+
+class SparqlLexer(RegexLexer):
+    """
+    Lexer for SPARQL query language.
+    """
+    name = 'SPARQL'
+    aliases = ['sparql']
+    filenames = ['*.rq', '*.sparql']
+    mimetypes = ['application/sparql-query']
+    url = 'https://www.w3.org/TR/sparql11-query'
+    version_added = '2.0'
+
+    # character group definitions ::
+
+    PN_CHARS_BASE_GRP = ('a-zA-Z'
+                         '\u00c0-\u00d6'
+                         '\u00d8-\u00f6'
+                         '\u00f8-\u02ff'
+                         '\u0370-\u037d'
+                         '\u037f-\u1fff'
+                         '\u200c-\u200d'
+                         '\u2070-\u218f'
+                         '\u2c00-\u2fef'
+                         '\u3001-\ud7ff'
+                         '\uf900-\ufdcf'
+                         '\ufdf0-\ufffd')
+
+    PN_CHARS_U_GRP = (PN_CHARS_BASE_GRP + '_')
+
+    PN_CHARS_GRP = (PN_CHARS_U_GRP +
+                    r'\-' +
+                    r'0-9' +
+                    '\u00b7' +
+                    '\u0300-\u036f' +
+                    '\u203f-\u2040')
+
+    HEX_GRP = '0-9A-Fa-f'
+
+    PN_LOCAL_ESC_CHARS_GRP = r' _~.\-!$&"()*+,;=/?#@%'
+
+    # terminal productions ::
+
+    PN_CHARS_BASE = '[' + PN_CHARS_BASE_GRP + ']'
+
+    PN_CHARS_U = '[' + PN_CHARS_U_GRP + ']'
+
+    PN_CHARS = '[' + PN_CHARS_GRP + ']'
+
+    HEX = '[' + HEX_GRP + ']'
+
+    PN_LOCAL_ESC_CHARS = '[' + PN_LOCAL_ESC_CHARS_GRP + ']'
+
+    IRIREF = r'<(?:[^<>"{}|^`\\\x00-\x20])*>'
+
+    BLANK_NODE_LABEL = '_:[0-9' + PN_CHARS_U_GRP + '](?:[' + PN_CHARS_GRP + \
+                       '.]*' + PN_CHARS + ')?'
+
+    PN_PREFIX = PN_CHARS_BASE + '(?:[' + PN_CHARS_GRP + '.]*' + PN_CHARS + ')?'
+
+    VARNAME = '[0-9' + PN_CHARS_U_GRP + '][' + PN_CHARS_U_GRP + \
+              '0-9\u00b7\u0300-\u036f\u203f-\u2040]*'
+
+    PERCENT = '%' + HEX + HEX
+
+    PN_LOCAL_ESC = r'\\' + PN_LOCAL_ESC_CHARS
+
+    PLX = '(?:' + PERCENT + ')|(?:' + PN_LOCAL_ESC + ')'
+
+    PN_LOCAL = ('(?:[' + PN_CHARS_U_GRP + ':0-9' + ']|' + PLX + ')' +
+                '(?:(?:[' + PN_CHARS_GRP + '.:]|' + PLX + ')*(?:[' +
+                PN_CHARS_GRP + ':]|' + PLX + '))?')
+
+    EXPONENT = r'[eE][+-]?\d+'
+
+    # Lexer token definitions ::
+
+    tokens = {
+        'root': [
+            (r'\s+', Text),
+            # keywords ::
+            (r'(?i)(select|construct|describe|ask|where|filter|group\s+by|minus|'
+             r'distinct|reduced|from\s+named|from|order\s+by|desc|asc|limit|'
+             r'offset|values|bindings|load|into|clear|drop|create|add|move|copy|'
+             r'insert\s+data|delete\s+data|delete\s+where|with|delete|insert|'
+             r'using\s+named|using|graph|default|named|all|optional|service|'
+             r'silent|bind|undef|union|not\s+in|in|as|having|to|prefix|base)\b', Keyword),
+            (r'(a)\b', Keyword),
+            # IRIs ::
+            ('(' + IRIREF + ')', Name.Label),
+            # blank nodes ::
+            ('(' + BLANK_NODE_LABEL + ')', Name.Label),
+            #  # variables ::
+            ('[?$]' + VARNAME, Name.Variable),
+            # prefixed names ::
+            (r'(' + PN_PREFIX + r')?(\:)(' + PN_LOCAL + r')?',
+             bygroups(Name.Namespace, Punctuation, Name.Tag)),
+            # function names ::
+            (r'(?i)(str|lang|langmatches|datatype|bound|iri|uri|bnode|rand|abs|'
+             r'ceil|floor|round|concat|strlen|ucase|lcase|encode_for_uri|'
+             r'contains|strstarts|strends|strbefore|strafter|year|month|day|'
+             r'hours|minutes|seconds|timezone|tz|now|uuid|struuid|md5|sha1|sha256|sha384|'
+             r'sha512|coalesce|if|strlang|strdt|sameterm|isiri|isuri|isblank|'
+             r'isliteral|isnumeric|regex|substr|replace|exists|not\s+exists|'
+             r'count|sum|min|max|avg|sample|group_concat|separator)\b',
+             Name.Function),
+            # boolean literals ::
+            (r'(true|false)', Keyword.Constant),
+            # double literals ::
+            (r'[+\-]?(\d+\.\d*' + EXPONENT + r'|\.?\d+' + EXPONENT + ')', Number.Float),
+            # decimal literals ::
+            (r'[+\-]?(\d+\.\d*|\.\d+)', Number.Float),
+            # integer literals ::
+            (r'[+\-]?\d+', Number.Integer),
+            # operators ::
+            (r'(\|\||&&|=|\*|\-|\+|/|!=|<=|>=|!|<|>)', Operator),
+            # punctuation characters ::
+            (r'[(){}.;,:^\[\]]', Punctuation),
+            # line comments ::
+            (r'#[^\n]*', Comment),
+            # strings ::
+            (r'"""', String, 'triple-double-quoted-string'),
+            (r'"', String, 'single-double-quoted-string'),
+            (r"'''", String, 'triple-single-quoted-string'),
+            (r"'", String, 'single-single-quoted-string'),
+        ],
+        'triple-double-quoted-string': [
+            (r'"""', String, 'end-of-string'),
+            (r'[^\\]+', String),
+            (r'\\', String, 'string-escape'),
+        ],
+        'single-double-quoted-string': [
+            (r'"', String, 'end-of-string'),
+            (r'[^"\\\n]+', String),
+            (r'\\', String, 'string-escape'),
+        ],
+        'triple-single-quoted-string': [
+            (r"'''", String, 'end-of-string'),
+            (r'[^\\]+', String),
+            (r'\\', String.Escape, 'string-escape'),
+        ],
+        'single-single-quoted-string': [
+            (r"'", String, 'end-of-string'),
+            (r"[^'\\\n]+", String),
+            (r'\\', String, 'string-escape'),
+        ],
+        'string-escape': [
+            (r'u' + HEX + '{4}', String.Escape, '#pop'),
+            (r'U' + HEX + '{8}', String.Escape, '#pop'),
+            (r'.', String.Escape, '#pop'),
+        ],
+        'end-of-string': [
+            (r'(@)([a-zA-Z]+(?:-[a-zA-Z0-9]+)*)',
+             bygroups(Operator, Name.Function), '#pop:2'),
+            (r'\^\^', Operator, '#pop:2'),
+            default('#pop:2'),
+        ],
+    }
+
+
+class TurtleLexer(RegexLexer):
+    """
+    Lexer for Turtle data language.
+    """
+    name = 'Turtle'
+    aliases = ['turtle']
+    filenames = ['*.ttl']
+    mimetypes = ['text/turtle', 'application/x-turtle']
+    url = 'https://www.w3.org/TR/turtle'
+    version_added = '2.1'
+
+    # character group definitions ::
+    PN_CHARS_BASE_GRP = ('a-zA-Z'
+                         '\u00c0-\u00d6'
+                         '\u00d8-\u00f6'
+                         '\u00f8-\u02ff'
+                         '\u0370-\u037d'
+                         '\u037f-\u1fff'
+                         '\u200c-\u200d'
+                         '\u2070-\u218f'
+                         '\u2c00-\u2fef'
+                         '\u3001-\ud7ff'
+                         '\uf900-\ufdcf'
+                         '\ufdf0-\ufffd')
+
+    PN_CHARS_U_GRP = (PN_CHARS_BASE_GRP + '_')
+
+    PN_CHARS_GRP = (PN_CHARS_U_GRP +
+                    r'\-' +
+                    r'0-9' +
+                    '\u00b7' +
+                    '\u0300-\u036f' +
+                    '\u203f-\u2040')
+
+    PN_CHARS = '[' + PN_CHARS_GRP + ']'
+
+    PN_CHARS_BASE = '[' + PN_CHARS_BASE_GRP + ']'
+
+    PN_PREFIX = PN_CHARS_BASE + '(?:[' + PN_CHARS_GRP + '.]*' + PN_CHARS + ')?'
+
+    HEX_GRP = '0-9A-Fa-f'
+
+    HEX = '[' + HEX_GRP + ']'
+
+    PERCENT = '%' + HEX + HEX
+
+    PN_LOCAL_ESC_CHARS_GRP = r' _~.\-!$&"()*+,;=/?#@%'
+
+    PN_LOCAL_ESC_CHARS = '[' + PN_LOCAL_ESC_CHARS_GRP + ']'
+
+    PN_LOCAL_ESC = r'\\' + PN_LOCAL_ESC_CHARS
+
+    PLX = '(?:' + PERCENT + ')|(?:' + PN_LOCAL_ESC + ')'
+
+    PN_LOCAL = ('(?:[' + PN_CHARS_U_GRP + ':0-9' + ']|' + PLX + ')' +
+                '(?:(?:[' + PN_CHARS_GRP + '.:]|' + PLX + ')*(?:[' +
+                PN_CHARS_GRP + ':]|' + PLX + '))?')
+
+    patterns = {
+        'PNAME_NS': r'((?:[a-zA-Z][\w-]*)?\:)',  # Simplified character range
+        'IRIREF': r'(<[^<>"{}|^`\\\x00-\x20]*>)'
+    }
+
+    tokens = {
+        'root': [
+            (r'\s+', Text),
+
+            # Base / prefix
+            (r'(@base|BASE)(\s+){IRIREF}(\s*)(\.?)'.format(**patterns),
+             bygroups(Keyword, Whitespace, Name.Variable, Whitespace,
+                      Punctuation)),
+            (r'(@prefix|PREFIX)(\s+){PNAME_NS}(\s+){IRIREF}(\s*)(\.?)'.format(**patterns),
+             bygroups(Keyword, Whitespace, Name.Namespace, Whitespace,
+                      Name.Variable, Whitespace, Punctuation)),
+
+            # The shorthand predicate 'a'
+            (r'(?<=\s)a(?=\s)', Keyword.Type),
+
+            # IRIREF
+            (r'{IRIREF}'.format(**patterns), Name.Variable),
+
+            # PrefixedName
+            (r'(' + PN_PREFIX + r')?(\:)(' + PN_LOCAL + r')?',
+             bygroups(Name.Namespace, Punctuation, Name.Tag)),
+
+            # BlankNodeLabel
+            (r'(_)(:)([' + PN_CHARS_U_GRP + r'0-9]([' + PN_CHARS_GRP + r'.]*' + PN_CHARS + ')?)',
+             bygroups(Name.Namespace, Punctuation, Name.Tag)),
+
+            # Comment
+            (r'#[^\n]+', Comment),
+
+            (r'\b(true|false)\b', Literal),
+            (r'[+\-]?\d*\.\d+', Number.Float),
+            (r'[+\-]?\d*(:?\.\d+)?E[+\-]?\d+', Number.Float),
+            (r'[+\-]?\d+', Number.Integer),
+            (r'[\[\](){}.;,:^]', Punctuation),
+
+            (r'"""', String, 'triple-double-quoted-string'),
+            (r'"', String, 'single-double-quoted-string'),
+            (r"'''", String, 'triple-single-quoted-string'),
+            (r"'", String, 'single-single-quoted-string'),
+        ],
+        'triple-double-quoted-string': [
+            (r'"""', String, 'end-of-string'),
+            (r'[^\\]+(?=""")', String),
+            (r'\\', String, 'string-escape'),
+        ],
+        'single-double-quoted-string': [
+            (r'"', String, 'end-of-string'),
+            (r'[^"\\\n]+', String),
+            (r'\\', String, 'string-escape'),
+        ],
+        'triple-single-quoted-string': [
+            (r"'''", String, 'end-of-string'),
+            (r"[^\\]+(?=''')", String),
+            (r'\\', String, 'string-escape'),
+        ],
+        'single-single-quoted-string': [
+            (r"'", String, 'end-of-string'),
+            (r"[^'\\\n]+", String),
+            (r'\\', String, 'string-escape'),
+        ],
+        'string-escape': [
+            (r'.', String, '#pop'),
+        ],
+        'end-of-string': [
+            (r'(@)([a-zA-Z]+(?:-[a-zA-Z0-9]+)*)',
+             bygroups(Operator, Generic.Emph), '#pop:2'),
+
+            (r'(\^\^){IRIREF}'.format(**patterns), bygroups(Operator, Generic.Emph), '#pop:2'),
+
+            default('#pop:2'),
+
+        ],
+    }
+
+    # Turtle and Tera Term macro files share the same file extension
+    # but each has a recognizable and distinct syntax.
+    def analyse_text(text):
+        for t in ('@base ', 'BASE ', '@prefix ', 'PREFIX '):
+            if re.search(rf'^\s*{t}', text):
+                return 0.80
+
+
+class ShExCLexer(RegexLexer):
+    """
+    Lexer for ShExC shape expressions language syntax.
+    """
+    name = 'ShExC'
+    aliases = ['shexc', 'shex']
+    filenames = ['*.shex']
+    mimetypes = ['text/shex']
+    url = 'https://shex.io/shex-semantics/#shexc'
+    version_added = ''
+
+    # character group definitions ::
+
+    PN_CHARS_BASE_GRP = ('a-zA-Z'
+                         '\u00c0-\u00d6'
+                         '\u00d8-\u00f6'
+                         '\u00f8-\u02ff'
+                         '\u0370-\u037d'
+                         '\u037f-\u1fff'
+                         '\u200c-\u200d'
+                         '\u2070-\u218f'
+                         '\u2c00-\u2fef'
+                         '\u3001-\ud7ff'
+                         '\uf900-\ufdcf'
+                         '\ufdf0-\ufffd')
+
+    PN_CHARS_U_GRP = (PN_CHARS_BASE_GRP + '_')
+
+    PN_CHARS_GRP = (PN_CHARS_U_GRP +
+                    r'\-' +
+                    r'0-9' +
+                    '\u00b7' +
+                    '\u0300-\u036f' +
+                    '\u203f-\u2040')
+
+    HEX_GRP = '0-9A-Fa-f'
+
+    PN_LOCAL_ESC_CHARS_GRP = r"_~.\-!$&'()*+,;=/?#@%"
+
+    # terminal productions ::
+
+    PN_CHARS_BASE = '[' + PN_CHARS_BASE_GRP + ']'
+
+    PN_CHARS_U = '[' + PN_CHARS_U_GRP + ']'
+
+    PN_CHARS = '[' + PN_CHARS_GRP + ']'
+
+    HEX = '[' + HEX_GRP + ']'
+
+    PN_LOCAL_ESC_CHARS = '[' + PN_LOCAL_ESC_CHARS_GRP + ']'
+
+    UCHAR_NO_BACKSLASH = '(?:u' + HEX + '{4}|U' + HEX + '{8})'
+
+    UCHAR = r'\\' + UCHAR_NO_BACKSLASH
+
+    IRIREF = r'<(?:[^\x00-\x20<>"{}|^`\\]|' + UCHAR + ')*>'
+
+    BLANK_NODE_LABEL = '_:[0-9' + PN_CHARS_U_GRP + '](?:[' + PN_CHARS_GRP + \
+                       '.]*' + PN_CHARS + ')?'
+
+    PN_PREFIX = PN_CHARS_BASE + '(?:[' + PN_CHARS_GRP + '.]*' + PN_CHARS + ')?'
+
+    PERCENT = '%' + HEX + HEX
+
+    PN_LOCAL_ESC = r'\\' + PN_LOCAL_ESC_CHARS
+
+    PLX = '(?:' + PERCENT + ')|(?:' + PN_LOCAL_ESC + ')'
+
+    PN_LOCAL = ('(?:[' + PN_CHARS_U_GRP + ':0-9' + ']|' + PLX + ')' +
+                '(?:(?:[' + PN_CHARS_GRP + '.:]|' + PLX + ')*(?:[' +
+                PN_CHARS_GRP + ':]|' + PLX + '))?')
+
+    EXPONENT = r'[eE][+-]?\d+'
+
+    # Lexer token definitions ::
+
+    tokens = {
+        'root': [
+            (r'\s+', Text),
+            # keywords ::
+            (r'(?i)(base|prefix|start|external|'
+             r'literal|iri|bnode|nonliteral|length|minlength|maxlength|'
+             r'mininclusive|minexclusive|maxinclusive|maxexclusive|'
+             r'totaldigits|fractiondigits|'
+             r'closed|extra)\b', Keyword),
+            (r'(a)\b', Keyword),
+            # IRIs ::
+            ('(' + IRIREF + ')', Name.Label),
+            # blank nodes ::
+            ('(' + BLANK_NODE_LABEL + ')', Name.Label),
+            # prefixed names ::
+            (r'(' + PN_PREFIX + r')?(\:)(' + PN_LOCAL + ')?',
+             bygroups(Name.Namespace, Punctuation, Name.Tag)),
+            # boolean literals ::
+            (r'(true|false)', Keyword.Constant),
+            # double literals ::
+            (r'[+\-]?(\d+\.\d*' + EXPONENT + r'|\.?\d+' + EXPONENT + ')', Number.Float),
+            # decimal literals ::
+            (r'[+\-]?(\d+\.\d*|\.\d+)', Number.Float),
+            # integer literals ::
+            (r'[+\-]?\d+', Number.Integer),
+            # operators ::
+            (r'[@|$&=*+?^\-~]', Operator),
+            # operator keywords ::
+            (r'(?i)(and|or|not)\b', Operator.Word),
+            # punctuation characters ::
+            (r'[(){}.;,:^\[\]]', Punctuation),
+            # line comments ::
+            (r'#[^\n]*', Comment),
+            # strings ::
+            (r'"""', String, 'triple-double-quoted-string'),
+            (r'"', String, 'single-double-quoted-string'),
+            (r"'''", String, 'triple-single-quoted-string'),
+            (r"'", String, 'single-single-quoted-string'),
+        ],
+        'triple-double-quoted-string': [
+            (r'"""', String, 'end-of-string'),
+            (r'[^\\]+', String),
+            (r'\\', String, 'string-escape'),
+        ],
+        'single-double-quoted-string': [
+            (r'"', String, 'end-of-string'),
+            (r'[^"\\\n]+', String),
+            (r'\\', String, 'string-escape'),
+        ],
+        'triple-single-quoted-string': [
+            (r"'''", String, 'end-of-string'),
+            (r'[^\\]+', String),
+            (r'\\', String.Escape, 'string-escape'),
+        ],
+        'single-single-quoted-string': [
+            (r"'", String, 'end-of-string'),
+            (r"[^'\\\n]+", String),
+            (r'\\', String, 'string-escape'),
+        ],
+        'string-escape': [
+            (UCHAR_NO_BACKSLASH, String.Escape, '#pop'),
+            (r'.', String.Escape, '#pop'),
+        ],
+        'end-of-string': [
+            (r'(@)([a-zA-Z]+(?:-[a-zA-Z0-9]+)*)',
+             bygroups(Operator, Name.Function), '#pop:2'),
+            (r'\^\^', Operator, '#pop:2'),
+            default('#pop:2'),
+        ],
+    }
diff --git a/pygments/lexers/rebol.py b/pygments/lexers/rebol.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b37a749450472b4adf9c6a613bc1b8ab503f5a2
--- /dev/null
+++ b/pygments/lexers/rebol.py
@@ -0,0 +1,419 @@
+"""
+    pygments.lexers.rebol
+    ~~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for the REBOL and related languages.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import re
+
+from pygments.lexer import RegexLexer, bygroups
+from pygments.token import Text, Comment, Operator, Keyword, Name, String, \
+    Number, Generic, Whitespace
+
+__all__ = ['RebolLexer', 'RedLexer']
+
+
+class RebolLexer(RegexLexer):
+    """
+    A REBOL lexer.
+    """
+    name = 'REBOL'
+    aliases = ['rebol']
+    filenames = ['*.r', '*.r3', '*.reb']
+    mimetypes = ['text/x-rebol']
+    url = 'http://www.rebol.com'
+    version_added = '1.1'
+
+    flags = re.IGNORECASE | re.MULTILINE
+
+    escape_re = r'(?:\^\([0-9a-f]{1,4}\)*)'
+
+    def word_callback(lexer, match):
+        word = match.group()
+
+        if re.match(".*:$", word):
+            yield match.start(), Generic.Subheading, word
+        elif re.match(
+            r'(native|alias|all|any|as-string|as-binary|bind|bound\?|case|'
+            r'catch|checksum|comment|debase|dehex|exclude|difference|disarm|'
+            r'either|else|enbase|foreach|remove-each|form|free|get|get-env|if|'
+            r'in|intersect|loop|minimum-of|maximum-of|mold|new-line|'
+            r'new-line\?|not|now|prin|print|reduce|compose|construct|repeat|'
+            r'reverse|save|script\?|set|shift|switch|throw|to-hex|trace|try|'
+            r'type\?|union|unique|unless|unprotect|unset|until|use|value\?|'
+            r'while|compress|decompress|secure|open|close|read|read-io|'
+            r'write-io|write|update|query|wait|input\?|exp|log-10|log-2|'
+            r'log-e|square-root|cosine|sine|tangent|arccosine|arcsine|'
+            r'arctangent|protect|lowercase|uppercase|entab|detab|connected\?|'
+            r'browse|launch|stats|get-modes|set-modes|to-local-file|'
+            r'to-rebol-file|encloak|decloak|create-link|do-browser|bind\?|'
+            r'hide|draw|show|size-text|textinfo|offset-to-caret|'
+            r'caret-to-offset|local-request-file|rgb-to-hsv|hsv-to-rgb|'
+            r'crypt-strength\?|dh-make-key|dh-generate-key|dh-compute-key|'
+            r'dsa-make-key|dsa-generate-key|dsa-make-signature|'
+            r'dsa-verify-signature|rsa-make-key|rsa-generate-key|'
+            r'rsa-encrypt)$', word):
+            yield match.start(), Name.Builtin, word
+        elif re.match(
+            r'(add|subtract|multiply|divide|remainder|power|and~|or~|xor~|'
+            r'minimum|maximum|negate|complement|absolute|random|head|tail|'
+            r'next|back|skip|at|pick|first|second|third|fourth|fifth|sixth|'
+            r'seventh|eighth|ninth|tenth|last|path|find|select|make|to|copy\*|'
+            r'insert|remove|change|poke|clear|trim|sort|min|max|abs|cp|'
+            r'copy)$', word):
+            yield match.start(), Name.Function, word
+        elif re.match(
+            r'(error|source|input|license|help|install|echo|Usage|with|func|'
+            r'throw-on-error|function|does|has|context|probe|\?\?|as-pair|'
+            r'mod|modulo|round|repend|about|set-net|append|join|rejoin|reform|'
+            r'remold|charset|array|replace|move|extract|forskip|forall|alter|'
+            r'first+|also|take|for|forever|dispatch|attempt|what-dir|'
+            r'change-dir|clean-path|list-dir|dirize|rename|split-path|delete|'
+            r'make-dir|delete-dir|in-dir|confirm|dump-obj|upgrade|what|'
+            r'build-tag|process-source|build-markup|decode-cgi|read-cgi|'
+            r'write-user|save-user|set-user-name|protect-system|parse-xml|'
+            r'cvs-date|cvs-version|do-boot|get-net-info|desktop|layout|'
+            r'scroll-para|get-face|alert|set-face|uninstall|unfocus|'
+            r'request-dir|center-face|do-events|net-error|decode-url|'
+            r'parse-header|parse-header-date|parse-email-addrs|import-email|'
+            r'send|build-attach-body|resend|show-popup|hide-popup|open-events|'
+            r'find-key-face|do-face|viewtop|confine|find-window|'
+            r'insert-event-func|remove-event-func|inform|dump-pane|dump-face|'
+            r'flag-face|deflag-face|clear-fields|read-net|vbug|path-thru|'
+            r'read-thru|load-thru|do-thru|launch-thru|load-image|'
+            r'request-download|do-face-alt|set-font|set-para|get-style|'
+            r'set-style|make-face|stylize|choose|hilight-text|hilight-all|'
+            r'unlight-text|focus|scroll-drag|clear-face|reset-face|scroll-face|'
+            r'resize-face|load-stock|load-stock-block|notify|request|flash|'
+            r'request-color|request-pass|request-text|request-list|'
+            r'request-date|request-file|dbug|editor|link-relative-path|'
+            r'emailer|parse-error)$', word):
+            yield match.start(), Keyword.Namespace, word
+        elif re.match(
+            r'(halt|quit|do|load|q|recycle|call|run|ask|parse|view|unview|'
+            r'return|exit|break)$', word):
+            yield match.start(), Name.Exception, word
+        elif re.match('REBOL$', word):
+            yield match.start(), Generic.Heading, word
+        elif re.match("to-.*", word):
+            yield match.start(), Keyword, word
+        elif re.match(r'(\+|-|\*|/|//|\*\*|and|or|xor|=\?|=|==|<>|<|>|<=|>=)$',
+                      word):
+            yield match.start(), Operator, word
+        elif re.match(r".*\?$", word):
+            yield match.start(), Keyword, word
+        elif re.match(r".*\!$", word):
+            yield match.start(), Keyword.Type, word
+        elif re.match("'.*", word):
+            yield match.start(), Name.Variable.Instance, word  # lit-word
+        elif re.match("#.*", word):
+            yield match.start(), Name.Label, word  # issue
+        elif re.match("%.*", word):
+            yield match.start(), Name.Decorator, word  # file
+        else:
+            yield match.start(), Name.Variable, word
+
+    tokens = {
+        'root': [
+            (r'\s+', Text),
+            (r'#"', String.Char, 'char'),
+            (r'#\{[0-9a-f]*\}', Number.Hex),
+            (r'2#\{', Number.Hex, 'bin2'),
+            (r'64#\{[0-9a-z+/=\s]*\}', Number.Hex),
+            (r'"', String, 'string'),
+            (r'\{', String, 'string2'),
+            (r';#+.*\n', Comment.Special),
+            (r';\*+.*\n', Comment.Preproc),
+            (r';.*\n', Comment),
+            (r'%"', Name.Decorator, 'stringFile'),
+            (r'%[^(^{")\s\[\]]+', Name.Decorator),
+            (r'[+-]?([a-z]{1,3})?\$\d+(\.\d+)?', Number.Float),  # money
+            (r'[+-]?\d+\:\d+(\:\d+)?(\.\d+)?', String.Other),    # time
+            (r'\d+[\-/][0-9a-z]+[\-/]\d+(\/\d+\:\d+((\:\d+)?'
+             r'([.\d+]?([+-]?\d+:\d+)?)?)?)?', String.Other),   # date
+            (r'\d+(\.\d+)+\.\d+', Keyword.Constant),             # tuple
+            (r'\d+X\d+', Keyword.Constant),                   # pair
+            (r'[+-]?\d+(\'\d+)?([.,]\d*)?E[+-]?\d+', Number.Float),
+            (r'[+-]?\d+(\'\d+)?[.,]\d*', Number.Float),
+            (r'[+-]?\d+(\'\d+)?', Number),
+            (r'[\[\]()]', Generic.Strong),
+            (r'[a-z]+[^(^{"\s:)]*://[^(^{"\s)]*', Name.Decorator),  # url
+            (r'mailto:[^(^{"@\s)]+@[^(^{"@\s)]+', Name.Decorator),  # url
+            (r'[^(^{"@\s)]+@[^(^{"@\s)]+', Name.Decorator),         # email
+            (r'comment\s"', Comment, 'commentString1'),
+            (r'comment\s\{', Comment, 'commentString2'),
+            (r'comment\s\[', Comment, 'commentBlock'),
+            (r'comment\s[^(\s{"\[]+', Comment),
+            (r'/[^(^{")\s/[\]]*', Name.Attribute),
+            (r'([^(^{")\s/[\]]+)(?=[:({"\s/\[\]])', word_callback),
+            (r'<[\w:.-]*>', Name.Tag),
+            (r'<[^(<>\s")]+', Name.Tag, 'tag'),
+            (r'([^(^{")\s]+)', Text),
+        ],
+        'string': [
+            (r'[^(^")]+', String),
+            (escape_re, String.Escape),
+            (r'[(|)]+', String),
+            (r'\^.', String.Escape),
+            (r'"', String, '#pop'),
+        ],
+        'string2': [
+            (r'[^(^{})]+', String),
+            (escape_re, String.Escape),
+            (r'[(|)]+', String),
+            (r'\^.', String.Escape),
+            (r'\{', String, '#push'),
+            (r'\}', String, '#pop'),
+        ],
+        'stringFile': [
+            (r'[^(^")]+', Name.Decorator),
+            (escape_re, Name.Decorator),
+            (r'\^.', Name.Decorator),
+            (r'"', Name.Decorator, '#pop'),
+        ],
+        'char': [
+            (escape_re + '"', String.Char, '#pop'),
+            (r'\^."', String.Char, '#pop'),
+            (r'."', String.Char, '#pop'),
+        ],
+        'tag': [
+            (escape_re, Name.Tag),
+            (r'"', Name.Tag, 'tagString'),
+            (r'[^(<>\r\n")]+', Name.Tag),
+            (r'>', Name.Tag, '#pop'),
+        ],
+        'tagString': [
+            (r'[^(^")]+', Name.Tag),
+            (escape_re, Name.Tag),
+            (r'[(|)]+', Name.Tag),
+            (r'\^.', Name.Tag),
+            (r'"', Name.Tag, '#pop'),
+        ],
+        'tuple': [
+            (r'(\d+\.)+', Keyword.Constant),
+            (r'\d+', Keyword.Constant, '#pop'),
+        ],
+        'bin2': [
+            (r'\s+', Number.Hex),
+            (r'([01]\s*){8}', Number.Hex),
+            (r'\}', Number.Hex, '#pop'),
+        ],
+        'commentString1': [
+            (r'[^(^")]+', Comment),
+            (escape_re, Comment),
+            (r'[(|)]+', Comment),
+            (r'\^.', Comment),
+            (r'"', Comment, '#pop'),
+        ],
+        'commentString2': [
+            (r'[^(^{})]+', Comment),
+            (escape_re, Comment),
+            (r'[(|)]+', Comment),
+            (r'\^.', Comment),
+            (r'\{', Comment, '#push'),
+            (r'\}', Comment, '#pop'),
+        ],
+        'commentBlock': [
+            (r'\[', Comment, '#push'),
+            (r'\]', Comment, '#pop'),
+            (r'"', Comment, "commentString1"),
+            (r'\{', Comment, "commentString2"),
+            (r'[^(\[\]"{)]+', Comment),
+        ],
+    }
+
+    def analyse_text(text):
+        """
+        Check if code contains REBOL header and so it probably not R code
+        """
+        if re.match(r'^\s*REBOL\s*\[', text, re.IGNORECASE):
+            # The code starts with REBOL header
+            return 1.0
+        elif re.search(r'\s*REBOL\s*\[', text, re.IGNORECASE):
+            # The code contains REBOL header but also some text before it
+            return 0.5
+
+
+class RedLexer(RegexLexer):
+    """
+    A Red-language lexer.
+    """
+    name = 'Red'
+    aliases = ['red', 'red/system']
+    filenames = ['*.red', '*.reds']
+    mimetypes = ['text/x-red', 'text/x-red-system']
+    url = 'https://www.red-lang.org'
+    version_added = '2.0'
+
+    flags = re.IGNORECASE | re.MULTILINE
+
+    escape_re = r'(?:\^\([0-9a-f]{1,4}\)*)'
+
+    def word_callback(lexer, match):
+        word = match.group()
+
+        if re.match(".*:$", word):
+            yield match.start(), Generic.Subheading, word
+        elif re.match(r'(if|unless|either|any|all|while|until|loop|repeat|'
+                      r'foreach|forall|func|function|does|has|switch|'
+                      r'case|reduce|compose|get|set|print|prin|equal\?|'
+                      r'not-equal\?|strict-equal\?|lesser\?|greater\?|lesser-or-equal\?|'
+                      r'greater-or-equal\?|same\?|not|type\?|stats|'
+                      r'bind|union|replace|charset|routine)$', word):
+            yield match.start(), Name.Builtin, word
+        elif re.match(r'(make|random|reflect|to|form|mold|absolute|add|divide|multiply|negate|'
+                      r'power|remainder|round|subtract|even\?|odd\?|and~|complement|or~|xor~|'
+                      r'append|at|back|change|clear|copy|find|head|head\?|index\?|insert|'
+                      r'length\?|next|pick|poke|remove|reverse|select|sort|skip|swap|tail|tail\?|'
+                      r'take|trim|create|close|delete|modify|open|open\?|query|read|rename|'
+                      r'update|write)$', word):
+            yield match.start(), Name.Function, word
+        elif re.match(r'(yes|on|no|off|true|false|tab|cr|lf|newline|escape|slash|sp|space|null|'
+                      r'none|crlf|dot|null-byte)$', word):
+            yield match.start(), Name.Builtin.Pseudo, word
+        elif re.match(r'(#system-global|#include|#enum|#define|#either|#if|#import|#export|'
+                      r'#switch|#default|#get-definition)$', word):
+            yield match.start(), Keyword.Namespace, word
+        elif re.match(r'(system|halt|quit|quit-return|do|load|q|recycle|call|run|ask|parse|'
+                      r'raise-error|return|exit|break|alias|push|pop|probe|\?\?|spec-of|body-of|'
+                      r'quote|forever)$', word):
+            yield match.start(), Name.Exception, word
+        elif re.match(r'(action\?|block\?|char\?|datatype\?|file\?|function\?|get-path\?|zero\?|'
+                      r'get-word\?|integer\?|issue\?|lit-path\?|lit-word\?|logic\?|native\?|'
+                      r'op\?|paren\?|path\?|refinement\?|set-path\?|set-word\?|string\?|unset\?|'
+                      r'any-struct\?|none\?|word\?|any-series\?)$', word):
+            yield match.start(), Keyword, word
+        elif re.match(r'(JNICALL|stdcall|cdecl|infix)$', word):
+            yield match.start(), Keyword.Namespace, word
+        elif re.match("to-.*", word):
+            yield match.start(), Keyword, word
+        elif re.match(r'(\+|-\*\*|-|\*\*|//|/|\*|and|or|xor|=\?|===|==|=|<>|<=|>=|'
+                      r'<<<|>>>|<<|>>|<|>%)$', word):
+            yield match.start(), Operator, word
+        elif re.match(r".*\!$", word):
+            yield match.start(), Keyword.Type, word
+        elif re.match("'.*", word):
+            yield match.start(), Name.Variable.Instance, word  # lit-word
+        elif re.match("#.*", word):
+            yield match.start(), Name.Label, word  # issue
+        elif re.match("%.*", word):
+            yield match.start(), Name.Decorator, word  # file
+        elif re.match(":.*", word):
+            yield match.start(), Generic.Subheading, word  # get-word
+        else:
+            yield match.start(), Name.Variable, word
+
+    tokens = {
+        'root': [
+            (r'\s+', Text),
+            (r'#"', String.Char, 'char'),
+            (r'#\{[0-9a-f\s]*\}', Number.Hex),
+            (r'2#\{', Number.Hex, 'bin2'),
+            (r'64#\{[0-9a-z+/=\s]*\}', Number.Hex),
+            (r'([0-9a-f]+)(h)((\s)|(?=[\[\]{}"()]))',
+             bygroups(Number.Hex, Name.Variable, Whitespace)),
+            (r'"', String, 'string'),
+            (r'\{', String, 'string2'),
+            (r';#+.*\n', Comment.Special),
+            (r';\*+.*\n', Comment.Preproc),
+            (r';.*\n', Comment),
+            (r'%"', Name.Decorator, 'stringFile'),
+            (r'%[^(^{")\s\[\]]+', Name.Decorator),
+            (r'[+-]?([a-z]{1,3})?\$\d+(\.\d+)?', Number.Float),  # money
+            (r'[+-]?\d+\:\d+(\:\d+)?(\.\d+)?', String.Other),    # time
+            (r'\d+[\-/][0-9a-z]+[\-/]\d+(/\d+:\d+((:\d+)?'
+             r'([\.\d+]?([+-]?\d+:\d+)?)?)?)?', String.Other),   # date
+            (r'\d+(\.\d+)+\.\d+', Keyword.Constant),             # tuple
+            (r'\d+X\d+', Keyword.Constant),                   # pair
+            (r'[+-]?\d+(\'\d+)?([.,]\d*)?E[+-]?\d+', Number.Float),
+            (r'[+-]?\d+(\'\d+)?[.,]\d*', Number.Float),
+            (r'[+-]?\d+(\'\d+)?', Number),
+            (r'[\[\]()]', Generic.Strong),
+            (r'[a-z]+[^(^{"\s:)]*://[^(^{"\s)]*', Name.Decorator),  # url
+            (r'mailto:[^(^{"@\s)]+@[^(^{"@\s)]+', Name.Decorator),  # url
+            (r'[^(^{"@\s)]+@[^(^{"@\s)]+', Name.Decorator),         # email
+            (r'comment\s"', Comment, 'commentString1'),
+            (r'comment\s\{', Comment, 'commentString2'),
+            (r'comment\s\[', Comment, 'commentBlock'),
+            (r'comment\s[^(\s{"\[]+', Comment),
+            (r'/[^(^{^")\s/[\]]*', Name.Attribute),
+            (r'([^(^{^")\s/[\]]+)(?=[:({"\s/\[\]])', word_callback),
+            (r'<[\w:.-]*>', Name.Tag),
+            (r'<[^(<>\s")]+', Name.Tag, 'tag'),
+            (r'([^(^{")\s]+)', Text),
+        ],
+        'string': [
+            (r'[^(^")]+', String),
+            (escape_re, String.Escape),
+            (r'[(|)]+', String),
+            (r'\^.', String.Escape),
+            (r'"', String, '#pop'),
+        ],
+        'string2': [
+            (r'[^(^{})]+', String),
+            (escape_re, String.Escape),
+            (r'[(|)]+', String),
+            (r'\^.', String.Escape),
+            (r'\{', String, '#push'),
+            (r'\}', String, '#pop'),
+        ],
+        'stringFile': [
+            (r'[^(^")]+', Name.Decorator),
+            (escape_re, Name.Decorator),
+            (r'\^.', Name.Decorator),
+            (r'"', Name.Decorator, '#pop'),
+        ],
+        'char': [
+            (escape_re + '"', String.Char, '#pop'),
+            (r'\^."', String.Char, '#pop'),
+            (r'."', String.Char, '#pop'),
+        ],
+        'tag': [
+            (escape_re, Name.Tag),
+            (r'"', Name.Tag, 'tagString'),
+            (r'[^(<>\r\n")]+', Name.Tag),
+            (r'>', Name.Tag, '#pop'),
+        ],
+        'tagString': [
+            (r'[^(^")]+', Name.Tag),
+            (escape_re, Name.Tag),
+            (r'[(|)]+', Name.Tag),
+            (r'\^.', Name.Tag),
+            (r'"', Name.Tag, '#pop'),
+        ],
+        'tuple': [
+            (r'(\d+\.)+', Keyword.Constant),
+            (r'\d+', Keyword.Constant, '#pop'),
+        ],
+        'bin2': [
+            (r'\s+', Number.Hex),
+            (r'([01]\s*){8}', Number.Hex),
+            (r'\}', Number.Hex, '#pop'),
+        ],
+        'commentString1': [
+            (r'[^(^")]+', Comment),
+            (escape_re, Comment),
+            (r'[(|)]+', Comment),
+            (r'\^.', Comment),
+            (r'"', Comment, '#pop'),
+        ],
+        'commentString2': [
+            (r'[^(^{})]+', Comment),
+            (escape_re, Comment),
+            (r'[(|)]+', Comment),
+            (r'\^.', Comment),
+            (r'\{', Comment, '#push'),
+            (r'\}', Comment, '#pop'),
+        ],
+        'commentBlock': [
+            (r'\[', Comment, '#push'),
+            (r'\]', Comment, '#pop'),
+            (r'"', Comment, "commentString1"),
+            (r'\{', Comment, "commentString2"),
+            (r'[^(\[\]"{)]+', Comment),
+        ],
+    }
diff --git a/pygments/lexers/rego.py b/pygments/lexers/rego.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f2e3e9e669bcf30a05f6e0e8deef87d79f65bd2
--- /dev/null
+++ b/pygments/lexers/rego.py
@@ -0,0 +1,57 @@
+"""
+    pygments.lexers.rego
+    ~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for the Rego policy languages.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, words
+from pygments.token import Comment, Operator, Keyword, Name, String, Number, Punctuation, Whitespace
+
+class RegoLexer(RegexLexer):
+    """
+    For Rego source.
+    """
+    name = 'Rego'
+    url = 'https://www.openpolicyagent.org/docs/latest/policy-language/'
+    filenames = ['*.rego']
+    aliases = ['rego']
+    mimetypes = ['text/x-rego']
+    version_added = '2.19'
+
+    reserved_words = (
+        'as', 'contains', 'data', 'default', 'else', 'every', 'false',
+        'if', 'in', 'import', 'package', 'not', 'null',
+        'some', 'true', 'with'
+    )
+
+    builtins = (
+        # https://www.openpolicyagent.org/docs/latest/philosophy/#the-opa-document-model
+        'data',  # Global variable for accessing base and virtual documents
+        'input', # Represents synchronously pushed base documents
+    )
+
+    tokens = {
+        'root': [
+            (r'\n', Whitespace),
+            (r'\s+', Whitespace),
+            (r'#.*?$', Comment.Single),
+            (words(reserved_words, suffix=r'\b'), Keyword),
+            (words(builtins, suffix=r'\b'), Name.Builtin),
+            (r'[a-zA-Z_][a-zA-Z0-9_]*', Name),
+            (r'"(\\\\|\\"|[^"])*"', String.Double),
+            (r'`[^`]*`', String.Backtick),
+            (r'-?\d+(\.\d+)?', Number),
+            (r'(==|!=|<=|>=|:=)', Operator),  # Compound operators
+            (r'[=<>+\-*/%&|]', Operator),     # Single-character operators
+            (r'[\[\]{}(),.:;]', Punctuation),
+        ]
+    }
+
+__all__ = ['RegoLexer']
+
+
+
diff --git a/pygments/lexers/resource.py b/pygments/lexers/resource.py
new file mode 100644
index 0000000000000000000000000000000000000000..9593c2124c1388aa2b7e821fada19347ab395d54
--- /dev/null
+++ b/pygments/lexers/resource.py
@@ -0,0 +1,83 @@
+"""
+    pygments.lexers.resource
+    ~~~~~~~~~~~~~~~~~~~~~~~~
+
+    Lexer for resource definition files.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import re
+
+from pygments.lexer import RegexLexer, bygroups, words
+from pygments.token import Comment, String, Number, Operator, Text, \
+    Keyword, Name
+
+__all__ = ['ResourceLexer']
+
+
+class ResourceLexer(RegexLexer):
+    """Lexer for ICU Resource bundles.
+    """
+    name = 'ResourceBundle'
+    aliases = ['resourcebundle', 'resource']
+    filenames = []
+    url = 'https://unicode-org.github.io/icu/userguide/locale/resources.html'
+    version_added = '2.0'
+
+    _types = (':table', ':array', ':string', ':bin', ':import', ':intvector',
+              ':int', ':alias')
+
+    flags = re.MULTILINE | re.IGNORECASE
+    tokens = {
+        'root': [
+            (r'//.*?$', Comment),
+            (r'"', String, 'string'),
+            (r'-?\d+', Number.Integer),
+            (r'[,{}]', Operator),
+            (r'([^\s{{:]+)(\s*)({}?)'.format('|'.join(_types)),
+             bygroups(Name, Text, Keyword)),
+            (r'\s+', Text),
+            (words(_types), Keyword),
+        ],
+        'string': [
+            (r'(\\x[0-9a-f]{2}|\\u[0-9a-f]{4}|\\U00[0-9a-f]{6}|'
+             r'\\[0-7]{1,3}|\\c.|\\[abtnvfre\'"?\\]|\\\{|[^"{\\])+', String),
+            (r'\{', String.Escape, 'msgname'),
+            (r'"', String, '#pop')
+        ],
+        'msgname': [
+            (r'([^{},]+)(\s*)', bygroups(Name, String.Escape), ('#pop', 'message'))
+        ],
+        'message': [
+            (r'\{', String.Escape, 'msgname'),
+            (r'\}', String.Escape, '#pop'),
+            (r'(,)(\s*)([a-z]+)(\s*\})',
+             bygroups(Operator, String.Escape, Keyword, String.Escape), '#pop'),
+            (r'(,)(\s*)([a-z]+)(\s*)(,)(\s*)(offset)(\s*)(:)(\s*)(-?\d+)(\s*)',
+             bygroups(Operator, String.Escape, Keyword, String.Escape, Operator,
+                      String.Escape, Operator.Word, String.Escape, Operator,
+                      String.Escape, Number.Integer, String.Escape), 'choice'),
+            (r'(,)(\s*)([a-z]+)(\s*)(,)(\s*)',
+             bygroups(Operator, String.Escape, Keyword, String.Escape, Operator,
+                      String.Escape), 'choice'),
+            (r'\s+', String.Escape)
+        ],
+        'choice': [
+            (r'(=|<|>|<=|>=|!=)(-?\d+)(\s*\{)',
+             bygroups(Operator, Number.Integer, String.Escape), 'message'),
+            (r'([a-z]+)(\s*\{)', bygroups(Keyword.Type, String.Escape), 'str'),
+            (r'\}', String.Escape, ('#pop', '#pop')),
+            (r'\s+', String.Escape)
+        ],
+        'str': [
+            (r'\}', String.Escape, '#pop'),
+            (r'\{', String.Escape, 'msgname'),
+            (r'[^{}]+', String)
+        ]
+    }
+
+    def analyse_text(text):
+        if text.startswith('root:table'):
+            return 1.0
diff --git a/pygments/lexers/ride.py b/pygments/lexers/ride.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d60c29cbb259a7a1d70cbf0a3ec08438f1f55b2
--- /dev/null
+++ b/pygments/lexers/ride.py
@@ -0,0 +1,138 @@
+"""
+    pygments.lexers.ride
+    ~~~~~~~~~~~~~~~~~~~~
+
+    Lexer for the Ride programming language.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, words, include
+from pygments.token import Comment, Keyword, Name, Number, Punctuation, \
+    String, Text
+
+__all__ = ['RideLexer']
+
+
+class RideLexer(RegexLexer):
+    """
+    For Ride source code.
+    """
+
+    name = 'Ride'
+    aliases = ['ride']
+    filenames = ['*.ride']
+    mimetypes = ['text/x-ride']
+    url = 'https://docs.waves.tech/en/ride'
+    version_added = '2.6'
+
+    validName = r'[a-zA-Z_][a-zA-Z0-9_\']*'
+
+    builtinOps = (
+        '||', '|', '>=', '>', '==', '!',
+        '=', '<=', '<', '::', ':+', ':', '!=', '/',
+        '.', '=>', '-', '+', '*', '&&', '%', '++',
+    )
+
+    globalVariablesName = (
+        'NOALG', 'MD5', 'SHA1', 'SHA224', 'SHA256', 'SHA384', 'SHA512',
+        'SHA3224', 'SHA3256', 'SHA3384', 'SHA3512', 'nil', 'this', 'unit',
+        'height', 'lastBlock', 'Buy', 'Sell', 'CEILING', 'FLOOR', 'DOWN',
+        'HALFDOWN', 'HALFEVEN', 'HALFUP', 'UP',
+    )
+
+    typesName = (
+        'Unit', 'Int', 'Boolean', 'ByteVector', 'String', 'Address', 'Alias',
+        'Transfer', 'AssetPair', 'DataEntry', 'Order', 'Transaction',
+        'GenesisTransaction', 'PaymentTransaction', 'ReissueTransaction',
+        'BurnTransaction', 'MassTransferTransaction', 'ExchangeTransaction',
+        'TransferTransaction', 'SetAssetScriptTransaction',
+        'InvokeScriptTransaction', 'IssueTransaction', 'LeaseTransaction',
+        'LeaseCancelTransaction', 'CreateAliasTransaction',
+        'SetScriptTransaction', 'SponsorFeeTransaction', 'DataTransaction',
+        'WriteSet', 'AttachedPayment', 'ScriptTransfer', 'TransferSet',
+        'ScriptResult', 'Invocation', 'Asset', 'BlockInfo', 'Issue', 'Reissue',
+        'Burn', 'NoAlg', 'Md5', 'Sha1', 'Sha224', 'Sha256', 'Sha384', 'Sha512',
+        'Sha3224', 'Sha3256', 'Sha3384', 'Sha3512', 'BinaryEntry',
+        'BooleanEntry', 'IntegerEntry', 'StringEntry', 'List', 'Ceiling',
+        'Down', 'Floor', 'HalfDown', 'HalfEven', 'HalfUp', 'Up',
+    )
+
+    functionsName = (
+        'fraction', 'size', 'toBytes', 'take', 'drop', 'takeRight', 'dropRight',
+        'toString', 'isDefined', 'extract', 'throw', 'getElement', 'value',
+        'cons', 'toUtf8String', 'toInt', 'indexOf', 'lastIndexOf', 'split',
+        'parseInt', 'parseIntValue', 'keccak256', 'blake2b256', 'sha256',
+        'sigVerify', 'toBase58String', 'fromBase58String', 'toBase64String',
+        'fromBase64String', 'transactionById', 'transactionHeightById',
+        'getInteger', 'getBoolean', 'getBinary', 'getString',
+        'addressFromPublicKey', 'addressFromString', 'addressFromRecipient',
+        'assetBalance', 'wavesBalance', 'getIntegerValue', 'getBooleanValue',
+        'getBinaryValue', 'getStringValue', 'addressFromStringValue',
+        'assetInfo', 'rsaVerify', 'checkMerkleProof', 'median',
+        'valueOrElse', 'valueOrErrorMessage', 'contains', 'log', 'pow',
+        'toBase16String', 'fromBase16String', 'blockInfoByHeight',
+        'transferTransactionById',
+    )
+
+    reservedWords = words((
+        'match', 'case', 'else', 'func', 'if',
+        'let', 'then', '@Callable', '@Verifier',
+    ), suffix=r'\b')
+
+    tokens = {
+        'root': [
+            # Comments
+            (r'#.*', Comment.Single),
+            # Whitespace
+            (r'\s+', Text),
+            # Strings
+            (r'"', String, 'doublequote'),
+            (r'utf8\'', String, 'utf8quote'),
+            (r'base(58|64|16)\'', String, 'singlequote'),
+            # Keywords
+            (reservedWords, Keyword.Reserved),
+            (r'\{-#.*?#-\}', Keyword.Reserved),
+            (r'FOLD<\d+>', Keyword.Reserved),
+            # Types
+            (words(typesName), Keyword.Type),
+            # Main
+            # (specialName, Keyword.Reserved),
+            # Prefix Operators
+            (words(builtinOps, prefix=r'\(', suffix=r'\)'), Name.Function),
+            # Infix Operators
+            (words(builtinOps), Name.Function),
+            (words(globalVariablesName), Name.Function),
+            (words(functionsName), Name.Function),
+            # Numbers
+            include('numbers'),
+            # Variable Names
+            (validName, Name.Variable),
+            # Parens
+            (r'[,()\[\]{}]', Punctuation),
+        ],
+
+        'doublequote': [
+            (r'\\u[0-9a-fA-F]{4}', String.Escape),
+            (r'\\[nrfvb\\"]', String.Escape),
+            (r'[^"]', String),
+            (r'"', String, '#pop'),
+        ],
+
+        'utf8quote': [
+            (r'\\u[0-9a-fA-F]{4}', String.Escape),
+            (r'\\[nrfvb\\\']', String.Escape),
+            (r'[^\']', String),
+            (r'\'', String, '#pop'),
+        ],
+
+        'singlequote': [
+            (r'[^\']', String),
+            (r'\'', String, '#pop'),
+        ],
+
+        'numbers': [
+            (r'_?\d+', Number.Integer),
+        ],
+    }
diff --git a/pygments/lexers/rita.py b/pygments/lexers/rita.py
new file mode 100644
index 0000000000000000000000000000000000000000..536aafffd2e88c9a5a4b6b4ff9d6cd42872a2162
--- /dev/null
+++ b/pygments/lexers/rita.py
@@ -0,0 +1,42 @@
+"""
+    pygments.lexers.rita
+    ~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for RITA language
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer
+from pygments.token import Comment, Operator, Keyword, Name, Literal, \
+    Punctuation, Whitespace
+
+__all__ = ['RitaLexer']
+
+
+class RitaLexer(RegexLexer):
+    """
+    Lexer for RITA.
+    """
+    name = 'Rita'
+    url = 'https://github.com/zaibacu/rita-dsl'
+    filenames = ['*.rita']
+    aliases = ['rita']
+    mimetypes = ['text/rita']
+    version_added = '2.11'
+
+    tokens = {
+        'root': [
+            (r'\n', Whitespace),
+            (r'\s+', Whitespace),
+            (r'#(.*?)\n', Comment.Single),
+            (r'@(.*?)\n', Operator),  # Yes, whole line as an operator
+            (r'"(\w|\d|\s|(\\")|[\'_\-./,\?\!])+?"', Literal),
+            (r'\'(\w|\d|\s|(\\\')|["_\-./,\?\!])+?\'', Literal),
+            (r'([A-Z_]+)', Keyword),
+            (r'([a-z0-9_]+)', Name),
+            (r'((->)|[!?+*|=])', Operator),
+            (r'[\(\),\{\}]', Punctuation)
+        ]
+    }
diff --git a/pygments/lexers/rnc.py b/pygments/lexers/rnc.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7a06bb918390f59753808c2f4f99009da5c51b1
--- /dev/null
+++ b/pygments/lexers/rnc.py
@@ -0,0 +1,66 @@
+"""
+    pygments.lexers.rnc
+    ~~~~~~~~~~~~~~~~~~~
+
+    Lexer for Relax-NG Compact syntax
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer
+from pygments.token import Text, Comment, Operator, Keyword, Name, String, \
+    Punctuation
+
+__all__ = ['RNCCompactLexer']
+
+
+class RNCCompactLexer(RegexLexer):
+    """
+    For RelaxNG-compact syntax.
+    """
+
+    name = 'Relax-NG Compact'
+    url = 'http://relaxng.org'
+    aliases = ['rng-compact', 'rnc']
+    filenames = ['*.rnc']
+    version_added = '2.2'
+
+    tokens = {
+        'root': [
+            (r'namespace\b', Keyword.Namespace),
+            (r'(?:default|datatypes)\b', Keyword.Declaration),
+            (r'##.*$', Comment.Preproc),
+            (r'#.*$', Comment.Single),
+            (r'"[^"]*"', String.Double),
+            # TODO single quoted strings and escape sequences outside of
+            # double-quoted strings
+            (r'(?:element|attribute|mixed)\b', Keyword.Declaration, 'variable'),
+            (r'(text\b|xsd:[^ ]+)', Keyword.Type, 'maybe_xsdattributes'),
+            (r'[,?&*=|~]|>>', Operator),
+            (r'[(){}]', Punctuation),
+            (r'.', Text),
+        ],
+
+        # a variable has been declared using `element` or `attribute`
+        'variable': [
+            (r'[^{]+', Name.Variable),
+            (r'\{', Punctuation, '#pop'),
+        ],
+
+        # after an xsd: declaration there may be attributes
+        'maybe_xsdattributes': [
+            (r'\{', Punctuation, 'xsdattributes'),
+            (r'\}', Punctuation, '#pop'),
+            (r'.', Text),
+        ],
+
+        # attributes take the form { key1 = value1 key2 = value2 ... }
+        'xsdattributes': [
+            (r'[^ =}]', Name.Attribute),
+            (r'=', Operator),
+            (r'"[^"]*"', String.Double),
+            (r'\}', Punctuation, '#pop'),
+            (r'.', Text),
+        ],
+    }
diff --git a/pygments/lexers/roboconf.py b/pygments/lexers/roboconf.py
new file mode 100644
index 0000000000000000000000000000000000000000..31adba9f48f99a4f9c38941e3c0d73714051a994
--- /dev/null
+++ b/pygments/lexers/roboconf.py
@@ -0,0 +1,81 @@
+"""
+    pygments.lexers.roboconf
+    ~~~~~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for Roboconf DSL.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, words, re
+from pygments.token import Text, Operator, Keyword, Name, Comment
+
+__all__ = ['RoboconfGraphLexer', 'RoboconfInstancesLexer']
+
+
+class RoboconfGraphLexer(RegexLexer):
+    """
+    Lexer for Roboconf graph files.
+    """
+    name = 'Roboconf Graph'
+    aliases = ['roboconf-graph']
+    filenames = ['*.graph']
+    url = 'https://roboconf.github.io/en/user-guide/graph-definition.html'
+    version_added = '2.1'
+
+    flags = re.IGNORECASE | re.MULTILINE
+    tokens = {
+        'root': [
+            # Skip white spaces
+            (r'\s+', Text),
+
+            # There is one operator
+            (r'=', Operator),
+
+            # Keywords
+            (words(('facet', 'import'), suffix=r'\s*\b', prefix=r'\b'), Keyword),
+            (words((
+                'installer', 'extends', 'exports', 'imports', 'facets',
+                'children'), suffix=r'\s*:?', prefix=r'\b'), Name),
+
+            # Comments
+            (r'#.*\n', Comment),
+
+            # Default
+            (r'[^#]', Text),
+            (r'.*\n', Text)
+        ]
+    }
+
+
+class RoboconfInstancesLexer(RegexLexer):
+    """
+    Lexer for Roboconf instances files.
+    """
+    name = 'Roboconf Instances'
+    aliases = ['roboconf-instances']
+    filenames = ['*.instances']
+    url = 'https://roboconf.github.io'
+    version_added = '2.1'
+
+    flags = re.IGNORECASE | re.MULTILINE
+    tokens = {
+        'root': [
+
+            # Skip white spaces
+            (r'\s+', Text),
+
+            # Keywords
+            (words(('instance of', 'import'), suffix=r'\s*\b', prefix=r'\b'), Keyword),
+            (words(('name', 'count'), suffix=r's*:?', prefix=r'\b'), Name),
+            (r'\s*[\w.-]+\s*:', Name),
+
+            # Comments
+            (r'#.*\n', Comment),
+
+            # Default
+            (r'[^#]', Text),
+            (r'.*\n', Text)
+        ]
+    }
diff --git a/pygments/lexers/robotframework.py b/pygments/lexers/robotframework.py
new file mode 100644
index 0000000000000000000000000000000000000000..f92d567503fbb371a224d531f504c59fbf77e1a8
--- /dev/null
+++ b/pygments/lexers/robotframework.py
@@ -0,0 +1,551 @@
+"""
+    pygments.lexers.robotframework
+    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+    Lexer for Robot Framework.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+#  Copyright 2012 Nokia Siemens Networks Oyj
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+
+import re
+
+from pygments.lexer import Lexer
+from pygments.token import Token
+
+__all__ = ['RobotFrameworkLexer']
+
+
+HEADING = Token.Generic.Heading
+SETTING = Token.Keyword.Namespace
+IMPORT = Token.Name.Namespace
+TC_KW_NAME = Token.Generic.Subheading
+KEYWORD = Token.Name.Function
+ARGUMENT = Token.String
+VARIABLE = Token.Name.Variable
+COMMENT = Token.Comment
+SEPARATOR = Token.Punctuation
+SYNTAX = Token.Punctuation
+GHERKIN = Token.Generic.Emph
+ERROR = Token.Error
+
+
+def normalize(string, remove=''):
+    string = string.lower()
+    for char in remove + ' ':
+        if char in string:
+            string = string.replace(char, '')
+    return string
+
+
+class RobotFrameworkLexer(Lexer):
+    """
+    For Robot Framework test data.
+
+    Supports both space and pipe separated plain text formats.
+    """
+    name = 'RobotFramework'
+    url = 'http://robotframework.org'
+    aliases = ['robotframework']
+    filenames = ['*.robot', '*.resource']
+    mimetypes = ['text/x-robotframework']
+    version_added = '1.6'
+
+    def __init__(self, **options):
+        options['tabsize'] = 2
+        options['encoding'] = 'UTF-8'
+        Lexer.__init__(self, **options)
+
+    def get_tokens_unprocessed(self, text):
+        row_tokenizer = RowTokenizer()
+        var_tokenizer = VariableTokenizer()
+        index = 0
+        for row in text.splitlines():
+            for value, token in row_tokenizer.tokenize(row):
+                for value, token in var_tokenizer.tokenize(value, token):
+                    if value:
+                        yield index, token, str(value)
+                        index += len(value)
+
+
+class VariableTokenizer:
+
+    def tokenize(self, string, token):
+        var = VariableSplitter(string, identifiers='$@%&')
+        if var.start < 0 or token in (COMMENT, ERROR):
+            yield string, token
+            return
+        for value, token in self._tokenize(var, string, token):
+            if value:
+                yield value, token
+
+    def _tokenize(self, var, string, orig_token):
+        before = string[:var.start]
+        yield before, orig_token
+        yield var.identifier + '{', SYNTAX
+        yield from self.tokenize(var.base, VARIABLE)
+        yield '}', SYNTAX
+        if var.index is not None:
+            yield '[', SYNTAX
+            yield from self.tokenize(var.index, VARIABLE)
+            yield ']', SYNTAX
+        yield from self.tokenize(string[var.end:], orig_token)
+
+
+class RowTokenizer:
+
+    def __init__(self):
+        self._table = UnknownTable()
+        self._splitter = RowSplitter()
+        testcases = TestCaseTable()
+        settings = SettingTable(testcases.set_default_template)
+        variables = VariableTable()
+        keywords = KeywordTable()
+        self._tables = {'settings': settings, 'setting': settings,
+                        'metadata': settings,
+                        'variables': variables, 'variable': variables,
+                        'testcases': testcases, 'testcase': testcases,
+                        'tasks': testcases, 'task': testcases,
+                        'keywords': keywords, 'keyword': keywords,
+                        'userkeywords': keywords, 'userkeyword': keywords}
+
+    def tokenize(self, row):
+        commented = False
+        heading = False
+        for index, value in enumerate(self._splitter.split(row)):
+            # First value, and every second after that, is a separator.
+            index, separator = divmod(index-1, 2)
+            if value.startswith('#'):
+                commented = True
+            elif index == 0 and value.startswith('*'):
+                self._table = self._start_table(value)
+                heading = True
+            yield from self._tokenize(value, index, commented,
+                                      separator, heading)
+        self._table.end_row()
+
+    def _start_table(self, header):
+        name = normalize(header, remove='*')
+        return self._tables.get(name, UnknownTable())
+
+    def _tokenize(self, value, index, commented, separator, heading):
+        if commented:
+            yield value, COMMENT
+        elif separator:
+            yield value, SEPARATOR
+        elif heading:
+            yield value, HEADING
+        else:
+            yield from self._table.tokenize(value, index)
+
+
+class RowSplitter:
+    _space_splitter = re.compile('( {2,})')
+    _pipe_splitter = re.compile(r'((?:^| +)\|(?: +|$))')
+
+    def split(self, row):
+        splitter = (row.startswith('| ') and self._split_from_pipes
+                    or self._split_from_spaces)
+        yield from splitter(row)
+        yield '\n'
+
+    def _split_from_spaces(self, row):
+        yield ''  # Start with (pseudo)separator similarly as with pipes
+        yield from self._space_splitter.split(row)
+
+    def _split_from_pipes(self, row):
+        _, separator, rest = self._pipe_splitter.split(row, 1)
+        yield separator
+        while self._pipe_splitter.search(rest):
+            cell, separator, rest = self._pipe_splitter.split(rest, 1)
+            yield cell
+            yield separator
+        yield rest
+
+
+class Tokenizer:
+    _tokens = None
+
+    def __init__(self):
+        self._index = 0
+
+    def tokenize(self, value):
+        values_and_tokens = self._tokenize(value, self._index)
+        self._index += 1
+        if isinstance(values_and_tokens, type(Token)):
+            values_and_tokens = [(value, values_and_tokens)]
+        return values_and_tokens
+
+    def _tokenize(self, value, index):
+        index = min(index, len(self._tokens) - 1)
+        return self._tokens[index]
+
+    def _is_assign(self, value):
+        if value.endswith('='):
+            value = value[:-1].strip()
+        var = VariableSplitter(value, identifiers='$@&')
+        return var.start == 0 and var.end == len(value)
+
+
+class Comment(Tokenizer):
+    _tokens = (COMMENT,)
+
+
+class Setting(Tokenizer):
+    _tokens = (SETTING, ARGUMENT)
+    _keyword_settings = ('suitesetup', 'suiteprecondition', 'suiteteardown',
+                         'suitepostcondition', 'testsetup', 'tasksetup', 'testprecondition',
+                         'testteardown','taskteardown', 'testpostcondition', 'testtemplate', 'tasktemplate')
+    _import_settings = ('library', 'resource', 'variables')
+    _other_settings = ('documentation', 'metadata', 'forcetags', 'defaulttags',
+                       'testtimeout','tasktimeout')
+    _custom_tokenizer = None
+
+    def __init__(self, template_setter=None):
+        Tokenizer.__init__(self)
+        self._template_setter = template_setter
+
+    def _tokenize(self, value, index):
+        if index == 1 and self._template_setter:
+            self._template_setter(value)
+        if index == 0:
+            normalized = normalize(value)
+            if normalized in self._keyword_settings:
+                self._custom_tokenizer = KeywordCall(support_assign=False)
+            elif normalized in self._import_settings:
+                self._custom_tokenizer = ImportSetting()
+            elif normalized not in self._other_settings:
+                return ERROR
+        elif self._custom_tokenizer:
+            return self._custom_tokenizer.tokenize(value)
+        return Tokenizer._tokenize(self, value, index)
+
+
+class ImportSetting(Tokenizer):
+    _tokens = (IMPORT, ARGUMENT)
+
+
+class TestCaseSetting(Setting):
+    _keyword_settings = ('setup', 'precondition', 'teardown', 'postcondition',
+                         'template')
+    _import_settings = ()
+    _other_settings = ('documentation', 'tags', 'timeout')
+
+    def _tokenize(self, value, index):
+        if index == 0:
+            type = Setting._tokenize(self, value[1:-1], index)
+            return [('[', SYNTAX), (value[1:-1], type), (']', SYNTAX)]
+        return Setting._tokenize(self, value, index)
+
+
+class KeywordSetting(TestCaseSetting):
+    _keyword_settings = ('teardown',)
+    _other_settings = ('documentation', 'arguments', 'return', 'timeout', 'tags')
+
+
+class Variable(Tokenizer):
+    _tokens = (SYNTAX, ARGUMENT)
+
+    def _tokenize(self, value, index):
+        if index == 0 and not self._is_assign(value):
+            return ERROR
+        return Tokenizer._tokenize(self, value, index)
+
+
+class KeywordCall(Tokenizer):
+    _tokens = (KEYWORD, ARGUMENT)
+
+    def __init__(self, support_assign=True):
+        Tokenizer.__init__(self)
+        self._keyword_found = not support_assign
+        self._assigns = 0
+
+    def _tokenize(self, value, index):
+        if not self._keyword_found and self._is_assign(value):
+            self._assigns += 1
+            return SYNTAX  # VariableTokenizer tokenizes this later.
+        if self._keyword_found:
+            return Tokenizer._tokenize(self, value, index - self._assigns)
+        self._keyword_found = True
+        return GherkinTokenizer().tokenize(value, KEYWORD)
+
+
+class GherkinTokenizer:
+    _gherkin_prefix = re.compile('^(Given|When|Then|And|But) ', re.IGNORECASE)
+
+    def tokenize(self, value, token):
+        match = self._gherkin_prefix.match(value)
+        if not match:
+            return [(value, token)]
+        end = match.end()
+        return [(value[:end], GHERKIN), (value[end:], token)]
+
+
+class TemplatedKeywordCall(Tokenizer):
+    _tokens = (ARGUMENT,)
+
+
+class ForLoop(Tokenizer):
+
+    def __init__(self):
+        Tokenizer.__init__(self)
+        self._in_arguments = False
+
+    def _tokenize(self, value, index):
+        token = self._in_arguments and ARGUMENT or SYNTAX
+        if value.upper() in ('IN', 'IN RANGE'):
+            self._in_arguments = True
+        return token
+
+
+class _Table:
+    _tokenizer_class = None
+
+    def __init__(self, prev_tokenizer=None):
+        self._tokenizer = self._tokenizer_class()
+        self._prev_tokenizer = prev_tokenizer
+        self._prev_values_on_row = []
+
+    def tokenize(self, value, index):
+        if self._continues(value, index):
+            self._tokenizer = self._prev_tokenizer
+            yield value, SYNTAX
+        else:
+            yield from self._tokenize(value, index)
+        self._prev_values_on_row.append(value)
+
+    def _continues(self, value, index):
+        return value == '...' and all(self._is_empty(t)
+                                      for t in self._prev_values_on_row)
+
+    def _is_empty(self, value):
+        return value in ('', '\\')
+
+    def _tokenize(self, value, index):
+        return self._tokenizer.tokenize(value)
+
+    def end_row(self):
+        self.__init__(prev_tokenizer=self._tokenizer)
+
+
+class UnknownTable(_Table):
+    _tokenizer_class = Comment
+
+    def _continues(self, value, index):
+        return False
+
+
+class VariableTable(_Table):
+    _tokenizer_class = Variable
+
+
+class SettingTable(_Table):
+    _tokenizer_class = Setting
+
+    def __init__(self, template_setter, prev_tokenizer=None):
+        _Table.__init__(self, prev_tokenizer)
+        self._template_setter = template_setter
+
+    def _tokenize(self, value, index):
+        if index == 0 and normalize(value) == 'testtemplate':
+            self._tokenizer = Setting(self._template_setter)
+        return _Table._tokenize(self, value, index)
+
+    def end_row(self):
+        self.__init__(self._template_setter, prev_tokenizer=self._tokenizer)
+
+
+class TestCaseTable(_Table):
+    _setting_class = TestCaseSetting
+    _test_template = None
+    _default_template = None
+
+    @property
+    def _tokenizer_class(self):
+        if self._test_template or (self._default_template and
+                                   self._test_template is not False):
+            return TemplatedKeywordCall
+        return KeywordCall
+
+    def _continues(self, value, index):
+        return index > 0 and _Table._continues(self, value, index)
+
+    def _tokenize(self, value, index):
+        if index == 0:
+            if value:
+                self._test_template = None
+            return GherkinTokenizer().tokenize(value, TC_KW_NAME)
+        if index == 1 and self._is_setting(value):
+            if self._is_template(value):
+                self._test_template = False
+                self._tokenizer = self._setting_class(self.set_test_template)
+            else:
+                self._tokenizer = self._setting_class()
+        if index == 1 and self._is_for_loop(value):
+            self._tokenizer = ForLoop()
+        if index == 1 and self._is_empty(value):
+            return [(value, SYNTAX)]
+        return _Table._tokenize(self, value, index)
+
+    def _is_setting(self, value):
+        return value.startswith('[') and value.endswith(']')
+
+    def _is_template(self, value):
+        return normalize(value) == '[template]'
+
+    def _is_for_loop(self, value):
+        return value.startswith(':') and normalize(value, remove=':') == 'for'
+
+    def set_test_template(self, template):
+        self._test_template = self._is_template_set(template)
+
+    def set_default_template(self, template):
+        self._default_template = self._is_template_set(template)
+
+    def _is_template_set(self, template):
+        return normalize(template) not in ('', '\\', 'none', '${empty}')
+
+
+class KeywordTable(TestCaseTable):
+    _tokenizer_class = KeywordCall
+    _setting_class = KeywordSetting
+
+    def _is_template(self, value):
+        return False
+
+
+# Following code copied directly from Robot Framework 2.7.5.
+
+class VariableSplitter:
+
+    def __init__(self, string, identifiers):
+        self.identifier = None
+        self.base = None
+        self.index = None
+        self.start = -1
+        self.end = -1
+        self._identifiers = identifiers
+        self._may_have_internal_variables = False
+        try:
+            self._split(string)
+        except ValueError:
+            pass
+        else:
+            self._finalize()
+
+    def get_replaced_base(self, variables):
+        if self._may_have_internal_variables:
+            return variables.replace_string(self.base)
+        return self.base
+
+    def _finalize(self):
+        self.identifier = self._variable_chars[0]
+        self.base = ''.join(self._variable_chars[2:-1])
+        self.end = self.start + len(self._variable_chars)
+        if self._has_list_or_dict_variable_index():
+            self.index = ''.join(self._list_and_dict_variable_index_chars[1:-1])
+            self.end += len(self._list_and_dict_variable_index_chars)
+
+    def _has_list_or_dict_variable_index(self):
+        return self._list_and_dict_variable_index_chars\
+        and self._list_and_dict_variable_index_chars[-1] == ']'
+
+    def _split(self, string):
+        start_index, max_index = self._find_variable(string)
+        self.start = start_index
+        self._open_curly = 1
+        self._state = self._variable_state
+        self._variable_chars = [string[start_index], '{']
+        self._list_and_dict_variable_index_chars = []
+        self._string = string
+        start_index += 2
+        for index, char in enumerate(string[start_index:]):
+            index += start_index  # Giving start to enumerate only in Py 2.6+
+            try:
+                self._state(char, index)
+            except StopIteration:
+                return
+            if index  == max_index and not self._scanning_list_variable_index():
+                return
+
+    def _scanning_list_variable_index(self):
+        return self._state in [self._waiting_list_variable_index_state,
+                               self._list_variable_index_state]
+
+    def _find_variable(self, string):
+        max_end_index = string.rfind('}')
+        if max_end_index == -1:
+            raise ValueError('No variable end found')
+        if self._is_escaped(string, max_end_index):
+            return self._find_variable(string[:max_end_index])
+        start_index = self._find_start_index(string, 1, max_end_index)
+        if start_index == -1:
+            raise ValueError('No variable start found')
+        return start_index, max_end_index
+
+    def _find_start_index(self, string, start, end):
+        index = string.find('{', start, end) - 1
+        if index < 0:
+            return -1
+        if self._start_index_is_ok(string, index):
+            return index
+        return self._find_start_index(string, index+2, end)
+
+    def _start_index_is_ok(self, string, index):
+        return string[index] in self._identifiers\
+        and not self._is_escaped(string, index)
+
+    def _is_escaped(self, string, index):
+        escaped = False
+        while index > 0 and string[index-1] == '\\':
+            index -= 1
+            escaped = not escaped
+        return escaped
+
+    def _variable_state(self, char, index):
+        self._variable_chars.append(char)
+        if char == '}' and not self._is_escaped(self._string, index):
+            self._open_curly -= 1
+            if self._open_curly == 0:
+                if not self._is_list_or_dict_variable():
+                    raise StopIteration
+                self._state = self._waiting_list_variable_index_state
+        elif char in self._identifiers:
+            self._state = self._internal_variable_start_state
+
+    def _is_list_or_dict_variable(self):
+        return self._variable_chars[0] in ('@','&')
+
+    def _internal_variable_start_state(self, char, index):
+        self._state = self._variable_state
+        if char == '{':
+            self._variable_chars.append(char)
+            self._open_curly += 1
+            self._may_have_internal_variables = True
+        else:
+            self._variable_state(char, index)
+
+    def _waiting_list_variable_index_state(self, char, index):
+        if char != '[':
+            raise StopIteration
+        self._list_and_dict_variable_index_chars.append(char)
+        self._state = self._list_variable_index_state
+
+    def _list_variable_index_state(self, char, index):
+        self._list_and_dict_variable_index_chars.append(char)
+        if char == ']':
+            raise StopIteration
diff --git a/pygments/lexers/ruby.py b/pygments/lexers/ruby.py
new file mode 100644
index 0000000000000000000000000000000000000000..72aaeb5fec08c7c14e3f827d9c297c485a087e67
--- /dev/null
+++ b/pygments/lexers/ruby.py
@@ -0,0 +1,518 @@
+"""
+    pygments.lexers.ruby
+    ~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for Ruby and related languages.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import re
+
+from pygments.lexer import Lexer, RegexLexer, ExtendedRegexLexer, include, \
+    bygroups, default, LexerContext, do_insertions, words, line_re
+from pygments.token import Text, Comment, Operator, Keyword, Name, String, \
+    Number, Punctuation, Error, Generic, Whitespace
+from pygments.util import shebang_matches
+
+__all__ = ['RubyLexer', 'RubyConsoleLexer', 'FancyLexer']
+
+
+RUBY_OPERATORS = (
+    '*', '**', '-', '+', '-@', '+@', '/', '%', '&', '|', '^', '`', '~',
+    '[]', '[]=', '<<', '>>', '<', '<>', '<=>', '>', '>=', '==', '==='
+)
+
+
+class RubyLexer(ExtendedRegexLexer):
+    """
+    For Ruby source code.
+    """
+
+    name = 'Ruby'
+    url = 'http://www.ruby-lang.org'
+    aliases = ['ruby', 'rb', 'duby']
+    filenames = ['*.rb', '*.rbw', 'Rakefile', '*.rake', '*.gemspec',
+                 '*.rbx', '*.duby', 'Gemfile', 'Vagrantfile']
+    mimetypes = ['text/x-ruby', 'application/x-ruby']
+    version_added = ''
+
+    flags = re.DOTALL | re.MULTILINE
+
+    def heredoc_callback(self, match, ctx):
+        # okay, this is the hardest part of parsing Ruby...
+        # match: 1 = <<[-~]?, 2 = quote? 3 = name 4 = quote? 5 = rest of line
+
+        start = match.start(1)
+        yield start, Operator, match.group(1)        # <<[-~]?
+        yield match.start(2), String.Heredoc, match.group(2)   # quote ", ', `
+        yield match.start(3), String.Delimiter, match.group(3) # heredoc name
+        yield match.start(4), String.Heredoc, match.group(4)   # quote again
+
+        heredocstack = ctx.__dict__.setdefault('heredocstack', [])
+        outermost = not bool(heredocstack)
+        heredocstack.append((match.group(1) in ('<<-', '<<~'), match.group(3)))
+
+        ctx.pos = match.start(5)
+        ctx.end = match.end(5)
+        # this may find other heredocs, so limit the recursion depth
+        if len(heredocstack) < 100:
+            yield from self.get_tokens_unprocessed(context=ctx)
+        else:
+            yield ctx.pos, String.Heredoc, match.group(5)
+        ctx.pos = match.end()
+
+        if outermost:
+            # this is the outer heredoc again, now we can process them all
+            for tolerant, hdname in heredocstack:
+                lines = []
+                for match in line_re.finditer(ctx.text, ctx.pos):
+                    if tolerant:
+                        check = match.group().strip()
+                    else:
+                        check = match.group().rstrip()
+                    if check == hdname:
+                        for amatch in lines:
+                            yield amatch.start(), String.Heredoc, amatch.group()
+                        yield match.start(), String.Delimiter, match.group()
+                        ctx.pos = match.end()
+                        break
+                    else:
+                        lines.append(match)
+                else:
+                    # end of heredoc not found -- error!
+                    for amatch in lines:
+                        yield amatch.start(), Error, amatch.group()
+            ctx.end = len(ctx.text)
+            del heredocstack[:]
+
+    def gen_rubystrings_rules():
+        def intp_regex_callback(self, match, ctx):
+            yield match.start(1), String.Regex, match.group(1)  # begin
+            nctx = LexerContext(match.group(3), 0, ['interpolated-regex'])
+            for i, t, v in self.get_tokens_unprocessed(context=nctx):
+                yield match.start(3)+i, t, v
+            yield match.start(4), String.Regex, match.group(4)  # end[mixounse]*
+            ctx.pos = match.end()
+
+        def intp_string_callback(self, match, ctx):
+            yield match.start(1), String.Other, match.group(1)
+            nctx = LexerContext(match.group(3), 0, ['interpolated-string'])
+            for i, t, v in self.get_tokens_unprocessed(context=nctx):
+                yield match.start(3)+i, t, v
+            yield match.start(4), String.Other, match.group(4)  # end
+            ctx.pos = match.end()
+
+        states = {}
+        states['strings'] = [
+            # easy ones
+            (r'\:@{0,2}[a-zA-Z_]\w*[!?]?', String.Symbol),
+            (words(RUBY_OPERATORS, prefix=r'\:@{0,2}'), String.Symbol),
+            (r":'(\\\\|\\[^\\]|[^'\\])*'", String.Symbol),
+            (r':"', String.Symbol, 'simple-sym'),
+            (r'([a-zA-Z_]\w*)(:)(?!:)',
+             bygroups(String.Symbol, Punctuation)),  # Since Ruby 1.9
+            (r'"', String.Double, 'simple-string-double'),
+            (r"'", String.Single, 'simple-string-single'),
+            (r'(?', '<>', 'ab'):
+            states[name+'-intp-string'] = [
+                (r'\\[\\' + bracecc + ']', String.Other),
+                (lbrace, String.Other, '#push'),
+                (rbrace, String.Other, '#pop'),
+                include('string-intp-escaped'),
+                (r'[\\#' + bracecc + ']', String.Other),
+                (r'[^\\#' + bracecc + ']+', String.Other),
+            ]
+            states['strings'].append((r'%[QWx]?' + lbrace, String.Other,
+                                      name+'-intp-string'))
+            states[name+'-string'] = [
+                (r'\\[\\' + bracecc + ']', String.Other),
+                (lbrace, String.Other, '#push'),
+                (rbrace, String.Other, '#pop'),
+                (r'[\\#' + bracecc + ']', String.Other),
+                (r'[^\\#' + bracecc + ']+', String.Other),
+            ]
+            states['strings'].append((r'%[qsw]' + lbrace, String.Other,
+                                      name+'-string'))
+            states[name+'-regex'] = [
+                (r'\\[\\' + bracecc + ']', String.Regex),
+                (lbrace, String.Regex, '#push'),
+                (rbrace + '[mixounse]*', String.Regex, '#pop'),
+                include('string-intp'),
+                (r'[\\#' + bracecc + ']', String.Regex),
+                (r'[^\\#' + bracecc + ']+', String.Regex),
+            ]
+            states['strings'].append((r'%r' + lbrace, String.Regex,
+                                      name+'-regex'))
+
+        # these must come after %!
+        states['strings'] += [
+            # %r regex
+            (r'(%r([\W_]))((?:\\\2|(?!\2).)*)(\2[mixounse]*)',
+             intp_regex_callback),
+            # regular fancy strings with qsw
+            (r'%[qsw]([\W_])((?:\\\1|(?!\1).)*)\1', String.Other),
+            (r'(%[QWx]([\W_]))((?:\\\2|(?!\2).)*)(\2)',
+             intp_string_callback),
+            # special forms of fancy strings after operators or
+            # in method calls with braces
+            (r'(?<=[-+/*%=<>&!^|~,(])(\s*)(%([\t ])(?:(?:\\\3|(?!\3).)*)\3)',
+             bygroups(Whitespace, String.Other, None)),
+            # and because of fixed width lookbehinds the whole thing a
+            # second time for line startings...
+            (r'^(\s*)(%([\t ])(?:(?:\\\3|(?!\3).)*)\3)',
+             bygroups(Whitespace, String.Other, None)),
+            # all regular fancy strings without qsw
+            (r'(%([^a-zA-Z0-9\s]))((?:\\\2|(?!\2).)*)(\2)',
+             intp_string_callback),
+        ]
+
+        return states
+
+    tokens = {
+        'root': [
+            (r'\A#!.+?$', Comment.Hashbang),
+            (r'#.*?$', Comment.Single),
+            (r'=begin\s.*?\n=end.*?$', Comment.Multiline),
+            # keywords
+            (words((
+                'BEGIN', 'END', 'alias', 'begin', 'break', 'case', 'defined?',
+                'do', 'else', 'elsif', 'end', 'ensure', 'for', 'if', 'in', 'next', 'redo',
+                'rescue', 'raise', 'retry', 'return', 'super', 'then', 'undef',
+                'unless', 'until', 'when', 'while', 'yield'), suffix=r'\b'),
+             Keyword),
+            # start of function, class and module names
+            (r'(module)(\s+)([a-zA-Z_]\w*'
+             r'(?:::[a-zA-Z_]\w*)*)',
+             bygroups(Keyword, Whitespace, Name.Namespace)),
+            (r'(def)(\s+)', bygroups(Keyword, Whitespace), 'funcname'),
+            (r'def(?=[*%&^`~+-/\[<>=])', Keyword, 'funcname'),
+            (r'(class)(\s+)', bygroups(Keyword, Whitespace), 'classname'),
+            # special methods
+            (words((
+                'initialize', 'new', 'loop', 'include', 'extend', 'raise', 'attr_reader',
+                'attr_writer', 'attr_accessor', 'attr', 'catch', 'throw', 'private',
+                'module_function', 'public', 'protected', 'true', 'false', 'nil'),
+                suffix=r'\b'),
+             Keyword.Pseudo),
+            (r'(not|and|or)\b', Operator.Word),
+            (words((
+                'autoload', 'block_given', 'const_defined', 'eql', 'equal', 'frozen', 'include',
+                'instance_of', 'is_a', 'iterator', 'kind_of', 'method_defined', 'nil',
+                'private_method_defined', 'protected_method_defined',
+                'public_method_defined', 'respond_to', 'tainted'), suffix=r'\?'),
+             Name.Builtin),
+            (r'(chomp|chop|exit|gsub|sub)!', Name.Builtin),
+            (words((
+                'Array', 'Float', 'Integer', 'String', '__id__', '__send__', 'abort',
+                'ancestors', 'at_exit', 'autoload', 'binding', 'callcc', 'caller',
+                'catch', 'chomp', 'chop', 'class_eval', 'class_variables',
+                'clone', 'const_defined?', 'const_get', 'const_missing', 'const_set',
+                'constants', 'display', 'dup', 'eval', 'exec', 'exit', 'extend', 'fail', 'fork',
+                'format', 'freeze', 'getc', 'gets', 'global_variables', 'gsub',
+                'hash', 'id', 'included_modules', 'inspect', 'instance_eval',
+                'instance_method', 'instance_methods',
+                'instance_variable_get', 'instance_variable_set', 'instance_variables',
+                'lambda', 'load', 'local_variables', 'loop',
+                'method', 'method_missing', 'methods', 'module_eval', 'name',
+                'object_id', 'open', 'p', 'print', 'printf', 'private_class_method',
+                'private_instance_methods',
+                'private_methods', 'proc', 'protected_instance_methods',
+                'protected_methods', 'public_class_method',
+                'public_instance_methods', 'public_methods',
+                'putc', 'puts', 'raise', 'rand', 'readline', 'readlines', 'require',
+                'scan', 'select', 'self', 'send', 'set_trace_func', 'singleton_methods', 'sleep',
+                'split', 'sprintf', 'srand', 'sub', 'syscall', 'system', 'taint',
+                'test', 'throw', 'to_a', 'to_s', 'trace_var', 'trap', 'untaint',
+                'untrace_var', 'warn'), prefix=r'(?~!:])|'
+             r'(?<=(?:\s|;)when\s)|'
+             r'(?<=(?:\s|;)or\s)|'
+             r'(?<=(?:\s|;)and\s)|'
+             r'(?<=\.index\s)|'
+             r'(?<=\.scan\s)|'
+             r'(?<=\.sub\s)|'
+             r'(?<=\.sub!\s)|'
+             r'(?<=\.gsub\s)|'
+             r'(?<=\.gsub!\s)|'
+             r'(?<=\.match\s)|'
+             r'(?<=(?:\s|;)if\s)|'
+             r'(?<=(?:\s|;)elsif\s)|'
+             r'(?<=^when\s)|'
+             r'(?<=^index\s)|'
+             r'(?<=^scan\s)|'
+             r'(?<=^sub\s)|'
+             r'(?<=^gsub\s)|'
+             r'(?<=^sub!\s)|'
+             r'(?<=^gsub!\s)|'
+             r'(?<=^match\s)|'
+             r'(?<=^if\s)|'
+             r'(?<=^elsif\s)'
+             r')(\s*)(/)', bygroups(Text, String.Regex), 'multiline-regex'),
+            # multiline regex (in method calls or subscripts)
+            (r'(?<=\(|,|\[)/', String.Regex, 'multiline-regex'),
+            # multiline regex (this time the funny no whitespace rule)
+            (r'(\s+)(/)(?![\s=])', bygroups(Whitespace, String.Regex),
+             'multiline-regex'),
+            # lex numbers and ignore following regular expressions which
+            # are division operators in fact (grrrr. i hate that. any
+            # better ideas?)
+            # since pygments 0.7 we also eat a "?" operator after numbers
+            # so that the char operator does not work. Chars are not allowed
+            # there so that you can use the ternary operator.
+            # stupid example:
+            #   x>=0?n[x]:""
+            (r'(0_?[0-7]+(?:_[0-7]+)*)(\s*)([/?])?',
+             bygroups(Number.Oct, Whitespace, Operator)),
+            (r'(0x[0-9A-Fa-f]+(?:_[0-9A-Fa-f]+)*)(\s*)([/?])?',
+             bygroups(Number.Hex, Whitespace, Operator)),
+            (r'(0b[01]+(?:_[01]+)*)(\s*)([/?])?',
+             bygroups(Number.Bin, Whitespace, Operator)),
+            (r'([\d]+(?:_\d+)*)(\s*)([/?])?',
+             bygroups(Number.Integer, Whitespace, Operator)),
+            # Names
+            (r'@@[a-zA-Z_]\w*', Name.Variable.Class),
+            (r'@[a-zA-Z_]\w*', Name.Variable.Instance),
+            (r'\$\w+', Name.Variable.Global),
+            (r'\$[!@&`\'+~=/\\,;.<>_*$?:"^-]', Name.Variable.Global),
+            (r'\$-[0adFiIlpvw]', Name.Variable.Global),
+            (r'::', Operator),
+            include('strings'),
+            # chars
+            (r'\?(\\[MC]-)*'  # modifiers
+             r'(\\([\\abefnrstv#"\']|x[a-fA-F0-9]{1,2}|[0-7]{1,3})|\S)'
+             r'(?!\w)',
+             String.Char),
+            (r'[A-Z]\w+', Name.Constant),
+            # this is needed because ruby attributes can look
+            # like keywords (class) or like this: ` ?!?
+            (words(RUBY_OPERATORS, prefix=r'(\.|::)'),
+             bygroups(Operator, Name.Operator)),
+            (r'(\.|::)([a-zA-Z_]\w*[!?]?|[*%&^`~+\-/\[<>=])',
+             bygroups(Operator, Name)),
+            (r'[a-zA-Z_]\w*[!?]?', Name),
+            (r'(\[|\]|\*\*|<>?|>=|<=|<=>|=~|={3}|'
+             r'!~|&&?|\|\||\.{1,3})', Operator),
+            (r'[-+/*%=<>&!^|~]=?', Operator),
+            (r'[(){};,/?:\\]', Punctuation),
+            (r'\s+', Whitespace)
+        ],
+        'funcname': [
+            (r'\(', Punctuation, 'defexpr'),
+            (r'(?:([a-zA-Z_]\w*)(\.))?'  # optional scope name, like "self."
+             r'('
+                r'[a-zA-Z\u0080-\uffff][a-zA-Z0-9_\u0080-\uffff]*[!?=]?'  # method name
+                r'|!=|!~|=~|\*\*?|[-+!~]@?|[/%&|^]|<=>|<[<=]?|>[>=]?|===?'  # or operator override
+                r'|\[\]=?'  # or element reference/assignment override
+                r'|`'  # or the undocumented backtick override
+             r')',
+             bygroups(Name.Class, Operator, Name.Function), '#pop'),
+            default('#pop')
+        ],
+        'classname': [
+            (r'\(', Punctuation, 'defexpr'),
+            (r'<<', Operator, '#pop'),
+            (r'[A-Z_]\w*', Name.Class, '#pop'),
+            default('#pop')
+        ],
+        'defexpr': [
+            (r'(\))(\.|::)?', bygroups(Punctuation, Operator), '#pop'),
+            (r'\(', Operator, '#push'),
+            include('root')
+        ],
+        'in-intp': [
+            (r'\{', String.Interpol, '#push'),
+            (r'\}', String.Interpol, '#pop'),
+            include('root'),
+        ],
+        'string-intp': [
+            (r'#\{', String.Interpol, 'in-intp'),
+            (r'#@@?[a-zA-Z_]\w*', String.Interpol),
+            (r'#\$[a-zA-Z_]\w*', String.Interpol)
+        ],
+        'string-intp-escaped': [
+            include('string-intp'),
+            (r'\\([\\abefnrstv#"\']|x[a-fA-F0-9]{1,2}|[0-7]{1,3})',
+             String.Escape)
+        ],
+        'interpolated-regex': [
+            include('string-intp'),
+            (r'[\\#]', String.Regex),
+            (r'[^\\#]+', String.Regex),
+        ],
+        'interpolated-string': [
+            include('string-intp'),
+            (r'[\\#]', String.Other),
+            (r'[^\\#]+', String.Other),
+        ],
+        'multiline-regex': [
+            include('string-intp'),
+            (r'\\\\', String.Regex),
+            (r'\\/', String.Regex),
+            (r'[\\#]', String.Regex),
+            (r'[^\\/#]+', String.Regex),
+            (r'/[mixounse]*', String.Regex, '#pop'),
+        ],
+        'end-part': [
+            (r'.+', Comment.Preproc, '#pop')
+        ]
+    }
+    tokens.update(gen_rubystrings_rules())
+
+    def analyse_text(text):
+        return shebang_matches(text, r'ruby(1\.\d)?')
+
+
+class RubyConsoleLexer(Lexer):
+    """
+    For Ruby interactive console (**irb**) output.
+    """
+    name = 'Ruby irb session'
+    aliases = ['rbcon', 'irb']
+    mimetypes = ['text/x-ruby-shellsession']
+    url = 'https://www.ruby-lang.org'
+    version_added = ''
+    _example = 'rbcon/console'
+
+    _prompt_re = re.compile(r'irb\([a-zA-Z_]\w*\):\d{3}:\d+[>*"\'] '
+                            r'|>> |\?> ')
+
+    def get_tokens_unprocessed(self, text):
+        rblexer = RubyLexer(**self.options)
+
+        curcode = ''
+        insertions = []
+        for match in line_re.finditer(text):
+            line = match.group()
+            m = self._prompt_re.match(line)
+            if m is not None:
+                end = m.end()
+                insertions.append((len(curcode),
+                                   [(0, Generic.Prompt, line[:end])]))
+                curcode += line[end:]
+            else:
+                if curcode:
+                    yield from do_insertions(
+                        insertions, rblexer.get_tokens_unprocessed(curcode))
+                    curcode = ''
+                    insertions = []
+                yield match.start(), Generic.Output, line
+        if curcode:
+            yield from do_insertions(
+                insertions, rblexer.get_tokens_unprocessed(curcode))
+
+
+class FancyLexer(RegexLexer):
+    """
+    Pygments Lexer For Fancy.
+
+    Fancy is a self-hosted, pure object-oriented, dynamic,
+    class-based, concurrent general-purpose programming language
+    running on Rubinius, the Ruby VM.
+    """
+    name = 'Fancy'
+    url = 'https://github.com/bakkdoor/fancy'
+    filenames = ['*.fy', '*.fancypack']
+    aliases = ['fancy', 'fy']
+    mimetypes = ['text/x-fancysrc']
+    version_added = '1.5'
+
+    tokens = {
+        # copied from PerlLexer:
+        'balanced-regex': [
+            (r'/(\\\\|\\[^\\]|[^/\\])*/[egimosx]*', String.Regex, '#pop'),
+            (r'!(\\\\|\\[^\\]|[^!\\])*![egimosx]*', String.Regex, '#pop'),
+            (r'\\(\\\\|[^\\])*\\[egimosx]*', String.Regex, '#pop'),
+            (r'\{(\\\\|\\[^\\]|[^}\\])*\}[egimosx]*', String.Regex, '#pop'),
+            (r'<(\\\\|\\[^\\]|[^>\\])*>[egimosx]*', String.Regex, '#pop'),
+            (r'\[(\\\\|\\[^\\]|[^\]\\])*\][egimosx]*', String.Regex, '#pop'),
+            (r'\((\\\\|\\[^\\]|[^)\\])*\)[egimosx]*', String.Regex, '#pop'),
+            (r'@(\\\\|\\[^\\]|[^@\\])*@[egimosx]*', String.Regex, '#pop'),
+            (r'%(\\\\|\\[^\\]|[^%\\])*%[egimosx]*', String.Regex, '#pop'),
+            (r'\$(\\\\|\\[^\\]|[^$\\])*\$[egimosx]*', String.Regex, '#pop'),
+        ],
+        'root': [
+            (r'\s+', Whitespace),
+
+            # balanced delimiters (copied from PerlLexer):
+            (r's\{(\\\\|\\[^\\]|[^}\\])*\}\s*', String.Regex, 'balanced-regex'),
+            (r's<(\\\\|\\[^\\]|[^>\\])*>\s*', String.Regex, 'balanced-regex'),
+            (r's\[(\\\\|\\[^\\]|[^\]\\])*\]\s*', String.Regex, 'balanced-regex'),
+            (r's\((\\\\|\\[^\\]|[^)\\])*\)\s*', String.Regex, 'balanced-regex'),
+            (r'm?/(\\\\|\\[^\\]|[^///\n])*/[gcimosx]*', String.Regex),
+            (r'm(?=[/!\\{<\[(@%$])', String.Regex, 'balanced-regex'),
+
+            # Comments
+            (r'#(.*?)\n', Comment.Single),
+            # Symbols
+            (r'\'([^\'\s\[\](){}]+|\[\])', String.Symbol),
+            # Multi-line DoubleQuotedString
+            (r'"""(\\\\|\\[^\\]|[^\\])*?"""', String),
+            # DoubleQuotedString
+            (r'"(\\\\|\\[^\\]|[^"\\])*"', String),
+            # keywords
+            (r'(def|class|try|catch|finally|retry|return|return_local|match|'
+             r'case|->|=>)\b', Keyword),
+            # constants
+            (r'(self|super|nil|false|true)\b', Name.Constant),
+            (r'[(){};,/?|:\\]', Punctuation),
+            # names
+            (words((
+                'Object', 'Array', 'Hash', 'Directory', 'File', 'Class', 'String',
+                'Number', 'Enumerable', 'FancyEnumerable', 'Block', 'TrueClass',
+                'NilClass', 'FalseClass', 'Tuple', 'Symbol', 'Stack', 'Set',
+                'FancySpec', 'Method', 'Package', 'Range'), suffix=r'\b'),
+             Name.Builtin),
+            # functions
+            (r'[a-zA-Z](\w|[-+?!=*/^><%])*:', Name.Function),
+            # operators, must be below functions
+            (r'[-+*/~,<>=&!?%^\[\].$]+', Operator),
+            (r'[A-Z]\w*', Name.Constant),
+            (r'@[a-zA-Z_]\w*', Name.Variable.Instance),
+            (r'@@[a-zA-Z_]\w*', Name.Variable.Class),
+            ('@@?', Operator),
+            (r'[a-zA-Z_]\w*', Name),
+            # numbers - / checks are necessary to avoid mismarking regexes,
+            # see comment in RubyLexer
+            (r'(0[oO]?[0-7]+(?:_[0-7]+)*)(\s*)([/?])?',
+             bygroups(Number.Oct, Whitespace, Operator)),
+            (r'(0[xX][0-9A-Fa-f]+(?:_[0-9A-Fa-f]+)*)(\s*)([/?])?',
+             bygroups(Number.Hex, Whitespace, Operator)),
+            (r'(0[bB][01]+(?:_[01]+)*)(\s*)([/?])?',
+             bygroups(Number.Bin, Whitespace, Operator)),
+            (r'([\d]+(?:_\d+)*)(\s*)([/?])?',
+             bygroups(Number.Integer, Whitespace, Operator)),
+            (r'\d+([eE][+-]?[0-9]+)|\d+\.\d+([eE][+-]?[0-9]+)?', Number.Float),
+            (r'\d+', Number.Integer)
+        ]
+    }
diff --git a/pygments/lexers/rust.py b/pygments/lexers/rust.py
new file mode 100644
index 0000000000000000000000000000000000000000..63410475553d44d7aa080e075ac71180f3160d6c
--- /dev/null
+++ b/pygments/lexers/rust.py
@@ -0,0 +1,222 @@
+"""
+    pygments.lexers.rust
+    ~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for the Rust language.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, include, bygroups, words, default
+from pygments.token import Text, Comment, Operator, Keyword, Name, String, \
+    Number, Punctuation, Whitespace
+
+__all__ = ['RustLexer']
+
+
+class RustLexer(RegexLexer):
+    """
+    Lexer for the Rust programming language (version 1.47).
+    """
+    name = 'Rust'
+    url = 'https://www.rust-lang.org/'
+    filenames = ['*.rs', '*.rs.in']
+    aliases = ['rust', 'rs']
+    mimetypes = ['text/rust', 'text/x-rust']
+    version_added = '1.6'
+
+    keyword_types = (words((
+        'u8', 'u16', 'u32', 'u64', 'u128', 'i8', 'i16', 'i32', 'i64', 'i128',
+        'usize', 'isize', 'f32', 'f64', 'char', 'str', 'bool',
+    ), suffix=r'\b'), Keyword.Type)
+
+    builtin_funcs_types = (words((
+        'Copy', 'Send', 'Sized', 'Sync', 'Unpin',
+        'Drop', 'Fn', 'FnMut', 'FnOnce', 'drop',
+        'Box', 'ToOwned', 'Clone',
+        'PartialEq', 'PartialOrd', 'Eq', 'Ord',
+        'AsRef', 'AsMut', 'Into', 'From', 'Default',
+        'Iterator', 'Extend', 'IntoIterator', 'DoubleEndedIterator',
+        'ExactSizeIterator',
+        'Option', 'Some', 'None',
+        'Result', 'Ok', 'Err',
+        'String', 'ToString', 'Vec',
+    ), suffix=r'\b'), Name.Builtin)
+
+    builtin_macros = (words((
+        'asm', 'assert', 'assert_eq', 'assert_ne', 'cfg', 'column',
+        'compile_error', 'concat', 'concat_idents', 'dbg', 'debug_assert',
+        'debug_assert_eq', 'debug_assert_ne', 'env', 'eprint', 'eprintln',
+        'file', 'format', 'format_args', 'format_args_nl', 'global_asm',
+        'include', 'include_bytes', 'include_str',
+        'is_aarch64_feature_detected',
+        'is_arm_feature_detected',
+        'is_mips64_feature_detected',
+        'is_mips_feature_detected',
+        'is_powerpc64_feature_detected',
+        'is_powerpc_feature_detected',
+        'is_x86_feature_detected',
+        'line', 'llvm_asm', 'log_syntax', 'macro_rules', 'matches',
+        'module_path', 'option_env', 'panic', 'print', 'println', 'stringify',
+        'thread_local', 'todo', 'trace_macros', 'unimplemented', 'unreachable',
+        'vec', 'write', 'writeln',
+    ), suffix=r'!'), Name.Function.Magic)
+
+    tokens = {
+        'root': [
+            # rust allows a file to start with a shebang, but if the first line
+            # starts with #![ then it's not a shebang but a crate attribute.
+            (r'#![^[\r\n].*$', Comment.Preproc),
+            default('base'),
+        ],
+        'base': [
+            # Whitespace and Comments
+            (r'\n', Whitespace),
+            (r'\s+', Whitespace),
+            (r'//!.*?\n', String.Doc),
+            (r'///(\n|[^/].*?\n)', String.Doc),
+            (r'//(.*?)\n', Comment.Single),
+            (r'/\*\*(\n|[^/*])', String.Doc, 'doccomment'),
+            (r'/\*!', String.Doc, 'doccomment'),
+            (r'/\*', Comment.Multiline, 'comment'),
+
+            # Macro parameters
+            (r"""\$([a-zA-Z_]\w*|\(,?|\),?|,?)""", Comment.Preproc),
+            # Keywords
+            (words(('as', 'async', 'await', 'box', 'const', 'crate', 'dyn',
+                    'else', 'extern', 'for', 'if', 'impl', 'in', 'loop',
+                    'match', 'move', 'mut', 'pub', 'ref', 'return', 'static',
+                    'super', 'trait', 'unsafe', 'use', 'where', 'while'),
+                   suffix=r'\b'), Keyword),
+            (words(('abstract', 'become', 'do', 'final', 'macro', 'override',
+                    'priv', 'typeof', 'try', 'unsized', 'virtual', 'yield'),
+                   suffix=r'\b'), Keyword.Reserved),
+            (r'(true|false)\b', Keyword.Constant),
+            (r'self\b', Name.Builtin.Pseudo),
+            (r'mod\b', Keyword, 'modname'),
+            (r'let\b', Keyword.Declaration),
+            (r'fn\b', Keyword, 'funcname'),
+            (r'(struct|enum|type|union)\b', Keyword, 'typename'),
+            (r'(default)(\s+)(type|fn)\b', bygroups(Keyword, Whitespace, Keyword)),
+            keyword_types,
+            (r'[sS]elf\b', Name.Builtin.Pseudo),
+            # Prelude (taken from Rust's src/libstd/prelude.rs)
+            builtin_funcs_types,
+            builtin_macros,
+            # Path separators, so types don't catch them.
+            (r'::\b', Punctuation),
+            # Types in positions.
+            (r'(?::|->)', Punctuation, 'typename'),
+            # Labels
+            (r'(break|continue)(\b\s*)(\'[A-Za-z_]\w*)?',
+             bygroups(Keyword, Text.Whitespace, Name.Label)),
+
+            # Character literals
+            (r"""'(\\['"\\nrt]|\\x[0-7][0-9a-fA-F]|\\0"""
+             r"""|\\u\{[0-9a-fA-F]{1,6}\}|.)'""",
+             String.Char),
+            (r"""b'(\\['"\\nrt]|\\x[0-9a-fA-F]{2}|\\0"""
+             r"""|\\u\{[0-9a-fA-F]{1,6}\}|.)'""",
+             String.Char),
+
+            # Binary literals
+            (r'0b[01_]+', Number.Bin, 'number_lit'),
+            # Octal literals
+            (r'0o[0-7_]+', Number.Oct, 'number_lit'),
+            # Hexadecimal literals
+            (r'0[xX][0-9a-fA-F_]+', Number.Hex, 'number_lit'),
+            # Decimal literals
+            (r'[0-9][0-9_]*(\.[0-9_]+[eE][+\-]?[0-9_]+|'
+             r'\.[0-9_]*(?!\.)|[eE][+\-]?[0-9_]+)', Number.Float,
+             'number_lit'),
+            (r'[0-9][0-9_]*', Number.Integer, 'number_lit'),
+
+            # String literals
+            (r'b"', String, 'bytestring'),
+            (r'"', String, 'string'),
+            (r'(?s)b?r(#*)".*?"\1', String),
+
+            # Lifetime names
+            (r"'", Operator, 'lifetime'),
+
+            # Operators and Punctuation
+            (r'\.\.=?', Operator),
+            (r'[{}()\[\],.;]', Punctuation),
+            (r'[+\-*/%&|<>^!~@=:?]', Operator),
+
+            # Identifiers
+            (r'[a-zA-Z_]\w*', Name),
+            # Raw identifiers
+            (r'r#[a-zA-Z_]\w*', Name),
+
+            # Attributes
+            (r'#!?\[', Comment.Preproc, 'attribute['),
+
+            # Misc
+            # Lone hashes: not used in Rust syntax, but allowed in macro
+            # arguments, most famously for quote::quote!()
+            (r'#', Punctuation),
+        ],
+        'comment': [
+            (r'[^*/]+', Comment.Multiline),
+            (r'/\*', Comment.Multiline, '#push'),
+            (r'\*/', Comment.Multiline, '#pop'),
+            (r'[*/]', Comment.Multiline),
+        ],
+        'doccomment': [
+            (r'[^*/]+', String.Doc),
+            (r'/\*', String.Doc, '#push'),
+            (r'\*/', String.Doc, '#pop'),
+            (r'[*/]', String.Doc),
+        ],
+        'modname': [
+            (r'\s+', Whitespace),
+            (r'[a-zA-Z_]\w*', Name.Namespace, '#pop'),
+            default('#pop'),
+        ],
+        'funcname': [
+            (r'\s+', Whitespace),
+            (r'[a-zA-Z_]\w*', Name.Function, '#pop'),
+            default('#pop'),
+        ],
+        'typename': [
+            (r'\s+', Whitespace),
+            (r'&', Keyword.Pseudo),
+            (r"'", Operator, 'lifetime'),
+            builtin_funcs_types,
+            keyword_types,
+            (r'[a-zA-Z_]\w*', Name.Class, '#pop'),
+            default('#pop'),
+        ],
+        'lifetime': [
+            (r"(static|_)", Name.Builtin),
+            (r"[a-zA-Z_]+\w*", Name.Attribute),
+            default('#pop'),
+        ],
+        'number_lit': [
+            (r'[ui](8|16|32|64|size)', Keyword, '#pop'),
+            (r'f(32|64)', Keyword, '#pop'),
+            default('#pop'),
+        ],
+        'string': [
+            (r'"', String, '#pop'),
+            (r"""\\['"\\nrt]|\\x[0-7][0-9a-fA-F]|\\0"""
+             r"""|\\u\{[0-9a-fA-F]{1,6}\}""", String.Escape),
+            (r'[^\\"]+', String),
+            (r'\\', String),
+        ],
+        'bytestring': [
+            (r"""\\x[89a-fA-F][0-9a-fA-F]""", String.Escape),
+            include('string'),
+        ],
+        'attribute_common': [
+            (r'"', String, 'string'),
+            (r'\[', Comment.Preproc, 'attribute['),
+        ],
+        'attribute[': [
+            include('attribute_common'),
+            (r'\]', Comment.Preproc, '#pop'),
+            (r'[^"\]\[]+', Comment.Preproc),
+        ],
+    }
diff --git a/pygments/lexers/sas.py b/pygments/lexers/sas.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b2ad432d231d21192c4f3653ef2e541588afd1d
--- /dev/null
+++ b/pygments/lexers/sas.py
@@ -0,0 +1,227 @@
+"""
+    pygments.lexers.sas
+    ~~~~~~~~~~~~~~~~~~~
+
+    Lexer for SAS.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import re
+from pygments.lexer import RegexLexer, include, words
+from pygments.token import Comment, Keyword, Name, Number, String, Text, \
+    Other, Generic
+
+__all__ = ['SASLexer']
+
+
+class SASLexer(RegexLexer):
+    """
+    For SAS files.
+    """
+    # Syntax from syntax/sas.vim by James Kidd 
+
+    name      = 'SAS'
+    aliases   = ['sas']
+    filenames = ['*.SAS', '*.sas']
+    mimetypes = ['text/x-sas', 'text/sas', 'application/x-sas']
+    url = 'https://en.wikipedia.org/wiki/SAS_(software)'
+    version_added = '2.2'
+    flags     = re.IGNORECASE | re.MULTILINE
+
+    builtins_macros = (
+        "bquote", "nrbquote", "cmpres", "qcmpres", "compstor", "datatyp",
+        "display", "do", "else", "end", "eval", "global", "goto", "if",
+        "index", "input", "keydef", "label", "left", "length", "let",
+        "local", "lowcase", "macro", "mend", "nrquote",
+        "nrstr", "put", "qleft", "qlowcase", "qscan",
+        "qsubstr", "qsysfunc", "qtrim", "quote", "qupcase", "scan",
+        "str", "substr", "superq", "syscall", "sysevalf", "sysexec",
+        "sysfunc", "sysget", "syslput", "sysprod", "sysrc", "sysrput",
+        "then", "to", "trim", "unquote", "until", "upcase", "verify",
+        "while", "window"
+    )
+
+    builtins_conditionals = (
+        "do", "if", "then", "else", "end", "until", "while"
+    )
+
+    builtins_statements = (
+        "abort", "array", "attrib", "by", "call", "cards", "cards4",
+        "catname", "continue", "datalines", "datalines4", "delete", "delim",
+        "delimiter", "display", "dm", "drop", "endsas", "error", "file",
+        "filename", "footnote", "format", "goto", "in", "infile", "informat",
+        "input", "keep", "label", "leave", "length", "libname", "link",
+        "list", "lostcard", "merge", "missing", "modify", "options", "output",
+        "out", "page", "put", "redirect", "remove", "rename", "replace",
+        "retain", "return", "select", "set", "skip", "startsas", "stop",
+        "title", "update", "waitsas", "where", "window", "x", "systask"
+    )
+
+    builtins_sql = (
+        "add", "and", "alter", "as", "cascade", "check", "create",
+        "delete", "describe", "distinct", "drop", "foreign", "from",
+        "group", "having", "index", "insert", "into", "in", "key", "like",
+        "message", "modify", "msgtype", "not", "null", "on", "or",
+        "order", "primary", "references", "reset", "restrict", "select",
+        "set", "table", "unique", "update", "validate", "view", "where"
+    )
+
+    builtins_functions = (
+        "abs", "addr", "airy", "arcos", "arsin", "atan", "attrc",
+        "attrn", "band", "betainv", "blshift", "bnot", "bor",
+        "brshift", "bxor", "byte", "cdf", "ceil", "cexist", "cinv",
+        "close", "cnonct", "collate", "compbl", "compound",
+        "compress", "cos", "cosh", "css", "curobs", "cv", "daccdb",
+        "daccdbsl", "daccsl", "daccsyd", "dacctab", "dairy", "date",
+        "datejul", "datepart", "datetime", "day", "dclose", "depdb",
+        "depdbsl", "depsl", "depsyd",
+        "deptab", "dequote", "dhms", "dif", "digamma",
+        "dim", "dinfo", "dnum", "dopen", "doptname", "doptnum",
+        "dread", "dropnote", "dsname", "erf", "erfc", "exist", "exp",
+        "fappend", "fclose", "fcol", "fdelete", "fetch", "fetchobs",
+        "fexist", "fget", "fileexist", "filename", "fileref",
+        "finfo", "finv", "fipname", "fipnamel", "fipstate", "floor",
+        "fnonct", "fnote", "fopen", "foptname", "foptnum", "fpoint",
+        "fpos", "fput", "fread", "frewind", "frlen", "fsep", "fuzz",
+        "fwrite", "gaminv", "gamma", "getoption", "getvarc", "getvarn",
+        "hbound", "hms", "hosthelp", "hour", "ibessel", "index",
+        "indexc", "indexw", "input", "inputc", "inputn", "int",
+        "intck", "intnx", "intrr", "irr", "jbessel", "juldate",
+        "kurtosis", "lag", "lbound", "left", "length", "lgamma",
+        "libname", "libref", "log", "log10", "log2", "logpdf", "logpmf",
+        "logsdf", "lowcase", "max", "mdy", "mean", "min", "minute",
+        "mod", "month", "mopen", "mort", "n", "netpv", "nmiss",
+        "normal", "note", "npv", "open", "ordinal", "pathname",
+        "pdf", "peek", "peekc", "pmf", "point", "poisson", "poke",
+        "probbeta", "probbnml", "probchi", "probf", "probgam",
+        "probhypr", "probit", "probnegb", "probnorm", "probt",
+        "put", "putc", "putn", "qtr", "quote", "ranbin", "rancau",
+        "ranexp", "rangam", "range", "rank", "rannor", "ranpoi",
+        "rantbl", "rantri", "ranuni", "repeat", "resolve", "reverse",
+        "rewind", "right", "round", "saving", "scan", "sdf", "second",
+        "sign", "sin", "sinh", "skewness", "soundex", "spedis",
+        "sqrt", "std", "stderr", "stfips", "stname", "stnamel",
+        "substr", "sum", "symget", "sysget", "sysmsg", "sysprod",
+        "sysrc", "system", "tan", "tanh", "time", "timepart", "tinv",
+        "tnonct", "today", "translate", "tranwrd", "trigamma",
+        "trim", "trimn", "trunc", "uniform", "upcase", "uss", "var",
+        "varfmt", "varinfmt", "varlabel", "varlen", "varname",
+        "varnum", "varray", "varrayx", "vartype", "verify", "vformat",
+        "vformatd", "vformatdx", "vformatn", "vformatnx", "vformatw",
+        "vformatwx", "vformatx", "vinarray", "vinarrayx", "vinformat",
+        "vinformatd", "vinformatdx", "vinformatn", "vinformatnx",
+        "vinformatw", "vinformatwx", "vinformatx", "vlabel",
+        "vlabelx", "vlength", "vlengthx", "vname", "vnamex", "vtype",
+        "vtypex", "weekday", "year", "yyq", "zipfips", "zipname",
+        "zipnamel", "zipstate"
+    )
+
+    tokens = {
+        'root': [
+            include('comments'),
+            include('proc-data'),
+            include('cards-datalines'),
+            include('logs'),
+            include('general'),
+            (r'.', Text),
+        ],
+        # SAS is multi-line regardless, but * is ended by ;
+        'comments': [
+            (r'^\s*\*.*?;', Comment),
+            (r'/\*.*?\*/', Comment),
+            (r'^\s*\*(.|\n)*?;', Comment.Multiline),
+            (r'/[*](.|\n)*?[*]/', Comment.Multiline),
+        ],
+        # Special highlight for proc, data, quit, run
+        'proc-data': [
+            (r'(^|;)\s*(proc \w+|data|run|quit)[\s;]',
+             Keyword.Reserved),
+        ],
+        # Special highlight cards and datalines
+        'cards-datalines': [
+            (r'^\s*(datalines|cards)\s*;\s*$', Keyword, 'data'),
+        ],
+        'data': [
+            (r'(.|\n)*^\s*;\s*$', Other, '#pop'),
+        ],
+        # Special highlight for put NOTE|ERROR|WARNING (order matters)
+        'logs': [
+            (r'\n?^\s*%?put ', Keyword, 'log-messages'),
+        ],
+        'log-messages': [
+            (r'NOTE(:|-).*', Generic, '#pop'),
+            (r'WARNING(:|-).*', Generic.Emph, '#pop'),
+            (r'ERROR(:|-).*', Generic.Error, '#pop'),
+            include('general'),
+        ],
+        'general': [
+            include('keywords'),
+            include('vars-strings'),
+            include('special'),
+            include('numbers'),
+        ],
+        # Keywords, statements, functions, macros
+        'keywords': [
+            (words(builtins_statements,
+                   prefix = r'\b',
+                   suffix = r'\b'),
+             Keyword),
+            (words(builtins_sql,
+                   prefix = r'\b',
+                   suffix = r'\b'),
+             Keyword),
+            (words(builtins_conditionals,
+                   prefix = r'\b',
+                   suffix = r'\b'),
+             Keyword),
+            (words(builtins_macros,
+                   prefix = r'%',
+                   suffix = r'\b'),
+             Name.Builtin),
+            (words(builtins_functions,
+                   prefix = r'\b',
+                   suffix = r'\('),
+             Name.Builtin),
+        ],
+        # Strings and user-defined variables and macros (order matters)
+        'vars-strings': [
+            (r'&[a-z_]\w{0,31}\.?', Name.Variable),
+            (r'%[a-z_]\w{0,31}', Name.Function),
+            (r'\'', String, 'string_squote'),
+            (r'"', String, 'string_dquote'),
+        ],
+        'string_squote': [
+            ('\'', String, '#pop'),
+            (r'\\\\|\\"|\\\n', String.Escape),
+            # AFAIK, macro variables are not evaluated in single quotes
+            # (r'&', Name.Variable, 'validvar'),
+            (r'[^$\'\\]+', String),
+            (r'[$\'\\]', String),
+        ],
+        'string_dquote': [
+            (r'"', String, '#pop'),
+            (r'\\\\|\\"|\\\n', String.Escape),
+            (r'&', Name.Variable, 'validvar'),
+            (r'[^$&"\\]+', String),
+            (r'[$"\\]', String),
+        ],
+        'validvar': [
+            (r'[a-z_]\w{0,31}\.?', Name.Variable, '#pop'),
+        ],
+        # SAS numbers and special variables
+        'numbers': [
+            (r'\b[+-]?([0-9]+(\.[0-9]+)?|\.[0-9]+|\.)(E[+-]?[0-9]+)?i?\b',
+             Number),
+        ],
+        'special': [
+            (r'(null|missing|_all_|_automatic_|_character_|_n_|'
+             r'_infile_|_name_|_null_|_numeric_|_user_|_webout_)',
+             Keyword.Constant),
+        ],
+        # 'operators': [
+        #     (r'(-|=|<=|>=|<|>|<>|&|!=|'
+        #      r'\||\*|\+|\^|/|!|~|~=)', Operator)
+        # ],
+    }
diff --git a/pygments/lexers/savi.py b/pygments/lexers/savi.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e443ae302a647b6c401d9c0e6fe451e12732dde
--- /dev/null
+++ b/pygments/lexers/savi.py
@@ -0,0 +1,171 @@
+"""
+    pygments.lexers.savi
+    ~~~~~~~~~~~~~~~~~~~~
+
+    Lexer for Savi.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, bygroups, include
+from pygments.token import Whitespace, Keyword, Name, String, Number, \
+  Operator, Punctuation, Comment, Generic, Error
+
+__all__ = ['SaviLexer']
+
+
+# The canonical version of this file can be found in the following repository,
+# where it is kept in sync with any language changes, as well as the other
+# pygments-like lexers that are maintained for use with other tools:
+# - https://github.com/savi-lang/savi/blob/main/tooling/pygments/lexers/savi.py
+#
+# If you're changing this file in the pygments repository, please ensure that
+# any changes you make are also propagated to the official Savi repository,
+# in order to avoid accidental clobbering of your changes later when an update
+# from the Savi repository flows forward into the pygments repository.
+#
+# If you're changing this file in the Savi repository, please ensure that
+# any changes you make are also reflected in the other pygments-like lexers
+# (rouge, vscode, etc) so that all of the lexers can be kept cleanly in sync.
+
+class SaviLexer(RegexLexer):
+    """
+    For Savi source code.
+
+    .. versionadded: 2.10
+    """
+
+    name = 'Savi'
+    url = 'https://github.com/savi-lang/savi'
+    aliases = ['savi']
+    filenames = ['*.savi']
+    version_added = ''
+
+    tokens = {
+      "root": [
+        # Line Comment
+        (r'//.*?$', Comment.Single),
+
+        # Doc Comment
+        (r'::.*?$', Comment.Single),
+
+        # Capability Operator
+        (r'(\')(\w+)(?=[^\'])', bygroups(Operator, Name)),
+
+        # Double-Quote String
+        (r'\w?"', String.Double, "string.double"),
+
+        # Single-Char String
+        (r"'", String.Char, "string.char"),
+
+        # Type Name
+        (r'(_?[A-Z]\w*)', Name.Class),
+
+        # Nested Type Name
+        (r'(\.)(\s*)(_?[A-Z]\w*)', bygroups(Punctuation, Whitespace, Name.Class)),
+
+        # Declare
+        (r'^([ \t]*)(:\w+)',
+          bygroups(Whitespace, Name.Tag),
+          "decl"),
+
+        # Error-Raising Calls/Names
+        (r'((\w+|\+|\-|\*)\!)', Generic.Deleted),
+
+        # Numeric Values
+        (r'\b\d([\d_]*(\.[\d_]+)?)\b', Number),
+
+        # Hex Numeric Values
+        (r'\b0x([0-9a-fA-F_]+)\b', Number.Hex),
+
+        # Binary Numeric Values
+        (r'\b0b([01_]+)\b', Number.Bin),
+
+        # Function Call (with braces)
+        (r'\w+(?=\()', Name.Function),
+
+        # Function Call (with receiver)
+        (r'(\.)(\s*)(\w+)', bygroups(Punctuation, Whitespace, Name.Function)),
+
+        # Function Call (with self receiver)
+        (r'(@)(\w+)', bygroups(Punctuation, Name.Function)),
+
+        # Parenthesis
+        (r'\(', Punctuation, "root"),
+        (r'\)', Punctuation, "#pop"),
+
+        # Brace
+        (r'\{', Punctuation, "root"),
+        (r'\}', Punctuation, "#pop"),
+
+        # Bracket
+        (r'\[', Punctuation, "root"),
+        (r'(\])(\!)', bygroups(Punctuation, Generic.Deleted), "#pop"),
+        (r'\]', Punctuation, "#pop"),
+
+        # Punctuation
+        (r'[,;:\.@]', Punctuation),
+
+        # Piping Operators
+        (r'(\|\>)', Operator),
+
+        # Branching Operators
+        (r'(\&\&|\|\||\?\?|\&\?|\|\?|\.\?)', Operator),
+
+        # Comparison Operators
+        (r'(\<\=\>|\=\~|\=\=|\<\=|\>\=|\<|\>)', Operator),
+
+        # Arithmetic Operators
+        (r'(\+|\-|\/|\*|\%)', Operator),
+
+        # Assignment Operators
+        (r'(\=)', Operator),
+
+        # Other Operators
+        (r'(\!|\<\<|\<|\&|\|)', Operator),
+
+        # Identifiers
+        (r'\b\w+\b', Name),
+
+        # Whitespace
+        (r'[ \t\r]+\n*|\n+', Whitespace),
+      ],
+
+      # Declare (nested rules)
+      "decl": [
+        (r'\b[a-z_]\w*\b(?!\!)', Keyword.Declaration),
+        (r':', Punctuation, "#pop"),
+        (r'\n', Whitespace, "#pop"),
+        include("root"),
+      ],
+
+      # Double-Quote String (nested rules)
+      "string.double": [
+        (r'\\\(', String.Interpol, "string.interpolation"),
+        (r'\\u[0-9a-fA-F]{4}', String.Escape),
+        (r'\\x[0-9a-fA-F]{2}', String.Escape),
+        (r'\\[bfnrt\\\']', String.Escape),
+        (r'\\"', String.Escape),
+        (r'"', String.Double, "#pop"),
+        (r'[^\\"]+', String.Double),
+        (r'.', Error),
+      ],
+
+      # Single-Char String (nested rules)
+      "string.char": [
+        (r'\\u[0-9a-fA-F]{4}', String.Escape),
+        (r'\\x[0-9a-fA-F]{2}', String.Escape),
+        (r'\\[bfnrt\\\']', String.Escape),
+        (r"\\'", String.Escape),
+        (r"'", String.Char, "#pop"),
+        (r"[^\\']+", String.Char),
+        (r'.', Error),
+      ],
+
+      # Interpolation inside String (nested rules)
+      "string.interpolation": [
+        (r"\)", String.Interpol, "#pop"),
+        include("root"),
+      ]
+    }
diff --git a/pygments/lexers/scdoc.py b/pygments/lexers/scdoc.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e850d02ed8ca77a90c9de217fac89dccd60b9ab
--- /dev/null
+++ b/pygments/lexers/scdoc.py
@@ -0,0 +1,85 @@
+"""
+    pygments.lexers.scdoc
+    ~~~~~~~~~~~~~~~~~~~~~
+
+    Lexer for scdoc, a simple man page generator.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import re
+
+from pygments.lexer import RegexLexer, include, bygroups, using, this
+from pygments.token import Text, Comment, Keyword, String, Generic
+
+__all__ = ['ScdocLexer']
+
+
+class ScdocLexer(RegexLexer):
+    """
+    `scdoc` is a simple man page generator for POSIX systems written in C99.
+    """
+    name = 'scdoc'
+    url = 'https://git.sr.ht/~sircmpwn/scdoc'
+    aliases = ['scdoc', 'scd']
+    filenames = ['*.scd', '*.scdoc']
+    version_added = '2.5'
+    flags = re.MULTILINE
+
+    tokens = {
+        'root': [
+            # comment
+            (r'^(;.+\n)', bygroups(Comment)),
+
+            # heading with pound prefix
+            (r'^(#)([^#].+\n)', bygroups(Generic.Heading, Text)),
+            (r'^(#{2})(.+\n)', bygroups(Generic.Subheading, Text)),
+            # bulleted lists
+            (r'^(\s*)([*-])(\s)(.+\n)',
+            bygroups(Text, Keyword, Text, using(this, state='inline'))),
+            # numbered lists
+            (r'^(\s*)(\.+\.)( .+\n)',
+            bygroups(Text, Keyword, using(this, state='inline'))),
+            # quote
+            (r'^(\s*>\s)(.+\n)', bygroups(Keyword, Generic.Emph)),
+            # text block
+            (r'^(```\n)([\w\W]*?)(^```$)', bygroups(String, Text, String)),
+
+            include('inline'),
+        ],
+        'inline': [
+            # escape
+            (r'\\.', Text),
+            # underlines
+            (r'(\s)(_[^_]+_)(\W|\n)', bygroups(Text, Generic.Emph, Text)),
+            # bold
+            (r'(\s)(\*[^*]+\*)(\W|\n)', bygroups(Text, Generic.Strong, Text)),
+            # inline code
+            (r'`[^`]+`', String.Backtick),
+
+            # general text, must come last!
+            (r'[^\\\s]+', Text),
+            (r'.', Text),
+        ],
+    }
+
+    def analyse_text(text):
+        """We checks for bold and underline text with * and _. Also
+        every scdoc file must start with a strictly defined first line."""
+        result = 0
+
+        if '*' in text:
+            result += 0.01
+
+        if '_' in text:
+            result += 0.01
+
+        # name(section) ["left_footer" ["center_header"]]
+        first_line = text.partition('\n')[0]
+        scdoc_preamble_pattern = r'^.*\([1-7]\)( "[^"]+"){0,2}$'
+
+        if re.search(scdoc_preamble_pattern, first_line):
+            result += 0.5
+
+        return result
diff --git a/pygments/lexers/scripting.py b/pygments/lexers/scripting.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e494c33b8ec2c939c6e55921a03aba232f5a456
--- /dev/null
+++ b/pygments/lexers/scripting.py
@@ -0,0 +1,1616 @@
+"""
+    pygments.lexers.scripting
+    ~~~~~~~~~~~~~~~~~~~~~~~~~
+
+    Lexer for scripting and embedded languages.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import re
+
+from pygments.lexer import RegexLexer, include, bygroups, default, combined, \
+    words
+from pygments.token import Text, Comment, Operator, Keyword, Name, String, \
+    Number, Punctuation, Error, Whitespace, Other
+from pygments.util import get_bool_opt, get_list_opt
+
+__all__ = ['LuaLexer', 'LuauLexer', 'MoonScriptLexer', 'ChaiscriptLexer', 'LSLLexer',
+           'AppleScriptLexer', 'RexxLexer', 'MOOCodeLexer', 'HybrisLexer',
+           'EasytrieveLexer', 'JclLexer', 'MiniScriptLexer']
+
+
+def all_lua_builtins():
+    from pygments.lexers._lua_builtins import MODULES
+    return [w for values in MODULES.values() for w in values]
+
+class LuaLexer(RegexLexer):
+    """
+    For Lua source code.
+
+    Additional options accepted:
+
+    `func_name_highlighting`
+        If given and ``True``, highlight builtin function names
+        (default: ``True``).
+    `disabled_modules`
+        If given, must be a list of module names whose function names
+        should not be highlighted. By default all modules are highlighted.
+
+        To get a list of allowed modules have a look into the
+        `_lua_builtins` module:
+
+        .. sourcecode:: pycon
+
+            >>> from pygments.lexers._lua_builtins import MODULES
+            >>> MODULES.keys()
+            ['string', 'coroutine', 'modules', 'io', 'basic', ...]
+    """
+
+    name = 'Lua'
+    url = 'https://www.lua.org/'
+    aliases = ['lua']
+    filenames = ['*.lua', '*.wlua']
+    mimetypes = ['text/x-lua', 'application/x-lua']
+    version_added = ''
+
+    _comment_multiline = r'(?:--\[(?P=*)\[[\w\W]*?\](?P=level)\])'
+    _comment_single = r'(?:--.*$)'
+    _space = r'(?:\s+(?!\s))'
+    _s = rf'(?:{_comment_multiline}|{_comment_single}|{_space})'
+    _name = r'(?:[^\W\d]\w*)'
+
+    tokens = {
+        'root': [
+            # Lua allows a file to start with a shebang.
+            (r'#!.*', Comment.Preproc),
+            default('base'),
+        ],
+        'ws': [
+            (_comment_multiline, Comment.Multiline),
+            (_comment_single, Comment.Single),
+            (_space, Whitespace),
+        ],
+        'base': [
+            include('ws'),
+
+            (r'(?i)0x[\da-f]*(\.[\da-f]*)?(p[+-]?\d+)?', Number.Hex),
+            (r'(?i)(\d*\.\d+|\d+\.\d*)(e[+-]?\d+)?', Number.Float),
+            (r'(?i)\d+e[+-]?\d+', Number.Float),
+            (r'\d+', Number.Integer),
+
+            # multiline strings
+            (r'(?s)\[(=*)\[.*?\]\1\]', String),
+
+            (r'::', Punctuation, 'label'),
+            (r'\.{3}', Punctuation),
+            (r'[=<>|~&+\-*/%#^]+|\.\.', Operator),
+            (r'[\[\]{}().,:;]+', Punctuation),
+            (r'(and|or|not)\b', Operator.Word),
+
+            (words([
+                'break', 'do', 'else', 'elseif', 'end', 'for', 'if', 'in',
+                'repeat', 'return', 'then', 'until', 'while'
+            ], suffix=r'\b'), Keyword.Reserved),
+            (r'goto\b', Keyword.Reserved, 'goto'),
+            (r'(local)\b', Keyword.Declaration),
+            (r'(true|false|nil)\b', Keyword.Constant),
+
+            (r'(function)\b', Keyword.Reserved, 'funcname'),
+
+            (words(all_lua_builtins(), suffix=r"\b"), Name.Builtin),
+            (fr'[A-Za-z_]\w*(?={_s}*[.:])', Name.Variable, 'varname'),
+            (fr'[A-Za-z_]\w*(?={_s}*\()', Name.Function),
+            (r'[A-Za-z_]\w*', Name.Variable),
+
+            ("'", String.Single, combined('stringescape', 'sqs')),
+            ('"', String.Double, combined('stringescape', 'dqs'))
+        ],
+
+        'varname': [
+            include('ws'),
+            (r'\.\.', Operator, '#pop'),
+            (r'[.:]', Punctuation),
+            (rf'{_name}(?={_s}*[.:])', Name.Property),
+            (rf'{_name}(?={_s}*\()', Name.Function, '#pop'),
+            (_name, Name.Property, '#pop'),
+        ],
+
+        'funcname': [
+            include('ws'),
+            (r'[.:]', Punctuation),
+            (rf'{_name}(?={_s}*[.:])', Name.Class),
+            (_name, Name.Function, '#pop'),
+            # inline function
+            (r'\(', Punctuation, '#pop'),
+        ],
+
+        'goto': [
+            include('ws'),
+            (_name, Name.Label, '#pop'),
+        ],
+
+        'label': [
+            include('ws'),
+            (r'::', Punctuation, '#pop'),
+            (_name, Name.Label),
+        ],
+
+        'stringescape': [
+            (r'\\([abfnrtv\\"\']|[\r\n]{1,2}|z\s*|x[0-9a-fA-F]{2}|\d{1,3}|'
+             r'u\{[0-9a-fA-F]+\})', String.Escape),
+        ],
+
+        'sqs': [
+            (r"'", String.Single, '#pop'),
+            (r"[^\\']+", String.Single),
+        ],
+
+        'dqs': [
+            (r'"', String.Double, '#pop'),
+            (r'[^\\"]+', String.Double),
+        ]
+    }
+
+    def __init__(self, **options):
+        self.func_name_highlighting = get_bool_opt(
+            options, 'func_name_highlighting', True)
+        self.disabled_modules = get_list_opt(options, 'disabled_modules', [])
+
+        self._functions = set()
+        if self.func_name_highlighting:
+            from pygments.lexers._lua_builtins import MODULES
+            for mod, func in MODULES.items():
+                if mod not in self.disabled_modules:
+                    self._functions.update(func)
+        RegexLexer.__init__(self, **options)
+
+    def get_tokens_unprocessed(self, text):
+        for index, token, value in \
+                RegexLexer.get_tokens_unprocessed(self, text):
+            if token is Name.Builtin and value not in self._functions:
+                if '.' in value:
+                    a, b = value.split('.')
+                    yield index, Name, a
+                    yield index + len(a), Punctuation, '.'
+                    yield index + len(a) + 1, Name, b
+                else:
+                    yield index, Name, value
+                continue
+            yield index, token, value
+
+def _luau_make_expression(should_pop, _s):
+    temp_list = [
+        (r'0[xX][\da-fA-F_]*', Number.Hex, '#pop'),
+        (r'0[bB][\d_]*', Number.Bin, '#pop'),
+        (r'\.?\d[\d_]*(?:\.[\d_]*)?(?:[eE][+-]?[\d_]+)?', Number.Float, '#pop'),
+
+        (words((
+            'true', 'false', 'nil'
+        ), suffix=r'\b'), Keyword.Constant, '#pop'),
+
+        (r'\[(=*)\[[.\n]*?\]\1\]', String, '#pop'),
+
+        (r'(\.)([a-zA-Z_]\w*)(?=%s*[({"\'])', bygroups(Punctuation, Name.Function), '#pop'),
+        (r'(\.)([a-zA-Z_]\w*)', bygroups(Punctuation, Name.Variable), '#pop'),
+
+        (rf'[a-zA-Z_]\w*(?:\.[a-zA-Z_]\w*)*(?={_s}*[({{"\'])', Name.Other, '#pop'),
+        (r'[a-zA-Z_]\w*(?:\.[a-zA-Z_]\w*)*', Name, '#pop'),
+    ]
+    if should_pop:
+        return temp_list
+    return [entry[:2] for entry in temp_list]
+
+def _luau_make_expression_special(should_pop):
+    temp_list = [
+        (r'\{', Punctuation, ('#pop', 'closing_brace_base', 'expression')),
+        (r'\(', Punctuation, ('#pop', 'closing_parenthesis_base', 'expression')),
+
+        (r'::?', Punctuation, ('#pop', 'type_end', 'type_start')),
+
+        (r"'", String.Single, ('#pop', 'string_single')),
+        (r'"', String.Double, ('#pop', 'string_double')),
+        (r'`', String.Backtick, ('#pop', 'string_interpolated')),
+    ]
+    if should_pop:
+        return temp_list
+    return [(entry[0], entry[1], entry[2][1:]) for entry in temp_list]
+
+class LuauLexer(RegexLexer):
+    """
+    For Luau source code.
+
+    Additional options accepted:
+
+    `include_luau_builtins`
+        If given and ``True``, automatically highlight Luau builtins
+        (default: ``True``).
+    `include_roblox_builtins`
+        If given and ``True``, automatically highlight Roblox-specific builtins
+        (default: ``False``).
+    `additional_builtins`
+        If given, must be a list of additional builtins to highlight.
+    `disabled_builtins`
+        If given, must be a list of builtins that will not be highlighted.
+    """
+
+    name = 'Luau'
+    url = 'https://luau-lang.org/'
+    aliases = ['luau']
+    filenames = ['*.luau']
+    version_added = '2.18'
+
+    _comment_multiline = r'(?:--\[(?P=*)\[[\w\W]*?\](?P=level)\])'
+    _comment_single = r'(?:--.*$)'
+    _s = r'(?:{}|{}|{})'.format(_comment_multiline, _comment_single, r'\s+')
+
+    tokens = {
+        'root': [
+            (r'#!.*', Comment.Hashbang, 'base'),
+            default('base'),
+        ],
+
+        'ws': [
+            (_comment_multiline, Comment.Multiline),
+            (_comment_single, Comment.Single),
+            (r'\s+', Whitespace),
+        ],
+
+        'base': [
+            include('ws'),
+
+            *_luau_make_expression_special(False),
+            (r'\.\.\.', Punctuation),
+
+            (rf'type\b(?={_s}+[a-zA-Z_])', Keyword.Reserved, 'type_declaration'),
+            (rf'export\b(?={_s}+[a-zA-Z_])', Keyword.Reserved),
+
+            (r'(?:\.\.|//|[+\-*\/%^<>=])=?', Operator, 'expression'),
+            (r'~=', Operator, 'expression'),
+
+            (words((
+                'and', 'or', 'not'
+            ), suffix=r'\b'), Operator.Word, 'expression'),
+
+            (words((
+                'elseif', 'for', 'if', 'in', 'repeat', 'return', 'until',
+                'while'), suffix=r'\b'), Keyword.Reserved, 'expression'),
+            (r'local\b', Keyword.Declaration, 'expression'),
+
+            (r'function\b', Keyword.Reserved, ('expression', 'func_name')),
+
+            (r'[\])};]+', Punctuation),
+
+            include('expression_static'),
+            *_luau_make_expression(False, _s),
+
+            (r'[\[.,]', Punctuation, 'expression'),
+        ],
+        'expression_static': [
+            (words((
+                'break', 'continue', 'do', 'else', 'elseif', 'end', 'for',
+                'if', 'in', 'repeat', 'return', 'then', 'until', 'while'),
+                suffix=r'\b'), Keyword.Reserved),
+        ],
+        'expression': [
+            include('ws'),
+
+            (r'if\b', Keyword.Reserved, ('ternary', 'expression')),
+
+            (r'local\b', Keyword.Declaration),
+            *_luau_make_expression_special(True),
+            (r'\.\.\.', Punctuation, '#pop'),
+
+            (r'function\b', Keyword.Reserved, 'func_name'),
+
+            include('expression_static'),
+            *_luau_make_expression(True, _s),
+
+            default('#pop'),
+        ],
+        'ternary': [
+            include('ws'),
+
+            (r'else\b', Keyword.Reserved, '#pop'),
+            (words((
+                'then', 'elseif',
+            ), suffix=r'\b'), Operator.Reserved, 'expression'),
+
+            default('#pop'),
+        ],
+
+        'closing_brace_pop': [
+            (r'\}', Punctuation, '#pop'),
+        ],
+        'closing_parenthesis_pop': [
+            (r'\)', Punctuation, '#pop'),
+        ],
+        'closing_gt_pop': [
+            (r'>', Punctuation, '#pop'),
+        ],
+
+        'closing_parenthesis_base': [
+            include('closing_parenthesis_pop'),
+            include('base'),
+        ],
+        'closing_parenthesis_type': [
+            include('closing_parenthesis_pop'),
+            include('type'),
+        ],
+        'closing_brace_base': [
+            include('closing_brace_pop'),
+            include('base'),
+        ],
+        'closing_brace_type': [
+            include('closing_brace_pop'),
+            include('type'),
+        ],
+        'closing_gt_type': [
+            include('closing_gt_pop'),
+            include('type'),
+        ],
+
+        'string_escape': [
+            (r'\\z\s*', String.Escape),
+            (r'\\(?:[abfnrtvz\\"\'`\{\n])|[\r\n]{1,2}|x[\da-fA-F]{2}|\d{1,3}|'
+             r'u\{\}[\da-fA-F]*\}', String.Escape),
+        ],
+        'string_single': [
+            include('string_escape'),
+
+            (r"'", String.Single, "#pop"),
+            (r"[^\\']+", String.Single),
+        ],
+        'string_double': [
+            include('string_escape'),
+
+            (r'"', String.Double, "#pop"),
+            (r'[^\\"]+', String.Double),
+        ],
+        'string_interpolated': [
+            include('string_escape'),
+
+            (r'\{', Punctuation, ('closing_brace_base', 'expression')),
+
+            (r'`', String.Backtick, "#pop"),
+            (r'[^\\`\{]+', String.Backtick),
+        ],
+
+        'func_name': [
+            include('ws'),
+
+            (r'[.:]', Punctuation),
+            (rf'[a-zA-Z_]\w*(?={_s}*[.:])', Name.Class),
+            (r'[a-zA-Z_]\w*', Name.Function),
+
+            (r'<', Punctuation, 'closing_gt_type'),
+
+            (r'\(', Punctuation, '#pop'),
+        ],
+
+        'type': [
+            include('ws'),
+
+            (r'\(', Punctuation, 'closing_parenthesis_type'),
+            (r'\{', Punctuation, 'closing_brace_type'),
+            (r'<', Punctuation, 'closing_gt_type'),
+
+            (r"'", String.Single, 'string_single'),
+            (r'"', String.Double, 'string_double'),
+
+            (r'[|&\.,\[\]:=]+', Punctuation),
+            (r'->', Punctuation),
+
+            (r'typeof\(', Name.Builtin, ('closing_parenthesis_base',
+                                         'expression')),
+            (r'[a-zA-Z_]\w*', Name.Class),
+        ],
+        'type_start': [
+            include('ws'),
+
+            (r'\(', Punctuation, ('#pop', 'closing_parenthesis_type')),
+            (r'\{', Punctuation, ('#pop', 'closing_brace_type')),
+            (r'<', Punctuation, ('#pop', 'closing_gt_type')),
+
+            (r"'", String.Single, ('#pop', 'string_single')),
+            (r'"', String.Double, ('#pop', 'string_double')),
+
+            (r'typeof\(', Name.Builtin, ('#pop', 'closing_parenthesis_base',
+                                         'expression')),
+            (r'[a-zA-Z_]\w*', Name.Class, '#pop'),
+        ],
+        'type_end': [
+            include('ws'),
+
+            (r'[|&\.]', Punctuation, 'type_start'),
+            (r'->', Punctuation, 'type_start'),
+
+            (r'<', Punctuation, 'closing_gt_type'),
+
+            default('#pop'),
+        ],
+        'type_declaration': [
+            include('ws'),
+
+            (r'[a-zA-Z_]\w*', Name.Class),
+            (r'<', Punctuation, 'closing_gt_type'),
+
+            (r'=', Punctuation, ('#pop', 'type_end', 'type_start')),
+        ],
+    }
+
+    def __init__(self, **options):
+        self.include_luau_builtins = get_bool_opt(
+            options, 'include_luau_builtins', True)
+        self.include_roblox_builtins = get_bool_opt(
+            options, 'include_roblox_builtins', False)
+        self.additional_builtins = get_list_opt(options, 'additional_builtins', [])
+        self.disabled_builtins = get_list_opt(options, 'disabled_builtins', [])
+
+        self._builtins = set(self.additional_builtins)
+        if self.include_luau_builtins:
+            from pygments.lexers._luau_builtins import LUAU_BUILTINS
+            self._builtins.update(LUAU_BUILTINS)
+        if self.include_roblox_builtins:
+            from pygments.lexers._luau_builtins import ROBLOX_BUILTINS
+            self._builtins.update(ROBLOX_BUILTINS)
+        if self.additional_builtins:
+            self._builtins.update(self.additional_builtins)
+        self._builtins.difference_update(self.disabled_builtins)
+
+        RegexLexer.__init__(self, **options)
+
+    def get_tokens_unprocessed(self, text):
+        for index, token, value in \
+                RegexLexer.get_tokens_unprocessed(self, text):
+            if token is Name or token is Name.Other:
+                split_value = value.split('.')
+                complete_value = []
+                new_index = index
+                for position in range(len(split_value), 0, -1):
+                    potential_string = '.'.join(split_value[:position])
+                    if potential_string in self._builtins:
+                        yield index, Name.Builtin, potential_string
+                        new_index += len(potential_string)
+
+                        if complete_value:
+                            yield new_index, Punctuation, '.'
+                            new_index += 1
+                        break
+                    complete_value.insert(0, split_value[position - 1])
+
+                for position, substring in enumerate(complete_value):
+                    if position + 1 == len(complete_value):
+                        if token is Name:
+                            yield new_index, Name.Variable, substring
+                            continue
+                        yield new_index, Name.Function, substring
+                        continue
+                    yield new_index, Name.Variable, substring
+                    new_index += len(substring)
+                    yield new_index, Punctuation, '.'
+                    new_index += 1
+
+                continue
+            yield index, token, value
+
+class MoonScriptLexer(LuaLexer):
+    """
+    For MoonScript source code.
+    """
+
+    name = 'MoonScript'
+    url = 'http://moonscript.org'
+    aliases = ['moonscript', 'moon']
+    filenames = ['*.moon']
+    mimetypes = ['text/x-moonscript', 'application/x-moonscript']
+    version_added = '1.5'
+
+    tokens = {
+        'root': [
+            (r'#!(.*?)$', Comment.Preproc),
+            default('base'),
+        ],
+        'base': [
+            ('--.*$', Comment.Single),
+            (r'(?i)(\d*\.\d+|\d+\.\d*)(e[+-]?\d+)?', Number.Float),
+            (r'(?i)\d+e[+-]?\d+', Number.Float),
+            (r'(?i)0x[0-9a-f]*', Number.Hex),
+            (r'\d+', Number.Integer),
+            (r'\n', Whitespace),
+            (r'[^\S\n]+', Text),
+            (r'(?s)\[(=*)\[.*?\]\1\]', String),
+            (r'(->|=>)', Name.Function),
+            (r':[a-zA-Z_]\w*', Name.Variable),
+            (r'(==|!=|~=|<=|>=|\.\.\.|\.\.|[=+\-*/%^<>#!.\\:])', Operator),
+            (r'[;,]', Punctuation),
+            (r'[\[\]{}()]', Keyword.Type),
+            (r'[a-zA-Z_]\w*:', Name.Variable),
+            (words((
+                'class', 'extends', 'if', 'then', 'super', 'do', 'with',
+                'import', 'export', 'while', 'elseif', 'return', 'for', 'in',
+                'from', 'when', 'using', 'else', 'and', 'or', 'not', 'switch',
+                'break'), suffix=r'\b'),
+             Keyword),
+            (r'(true|false|nil)\b', Keyword.Constant),
+            (r'(and|or|not)\b', Operator.Word),
+            (r'(self)\b', Name.Builtin.Pseudo),
+            (r'@@?([a-zA-Z_]\w*)?', Name.Variable.Class),
+            (r'[A-Z]\w*', Name.Class),  # proper name
+            (words(all_lua_builtins(), suffix=r"\b"), Name.Builtin),
+            (r'[A-Za-z_]\w*', Name),
+            ("'", String.Single, combined('stringescape', 'sqs')),
+            ('"', String.Double, combined('stringescape', 'dqs'))
+        ],
+        'stringescape': [
+            (r'''\\([abfnrtv\\"']|\d{1,3})''', String.Escape)
+        ],
+        'sqs': [
+            ("'", String.Single, '#pop'),
+            ("[^']+", String)
+        ],
+        'dqs': [
+            ('"', String.Double, '#pop'),
+            ('[^"]+', String)
+        ]
+    }
+
+    def get_tokens_unprocessed(self, text):
+        # set . as Operator instead of Punctuation
+        for index, token, value in LuaLexer.get_tokens_unprocessed(self, text):
+            if token == Punctuation and value == ".":
+                token = Operator
+            yield index, token, value
+
+
+class ChaiscriptLexer(RegexLexer):
+    """
+    For ChaiScript source code.
+    """
+
+    name = 'ChaiScript'
+    url = 'http://chaiscript.com/'
+    aliases = ['chaiscript', 'chai']
+    filenames = ['*.chai']
+    mimetypes = ['text/x-chaiscript', 'application/x-chaiscript']
+    version_added = '2.0'
+
+    flags = re.DOTALL | re.MULTILINE
+
+    tokens = {
+        'commentsandwhitespace': [
+            (r'\s+', Text),
+            (r'//.*?\n', Comment.Single),
+            (r'/\*.*?\*/', Comment.Multiline),
+            (r'^\#.*?\n', Comment.Single)
+        ],
+        'slashstartsregex': [
+            include('commentsandwhitespace'),
+            (r'/(\\.|[^[/\\\n]|\[(\\.|[^\]\\\n])*])+/'
+             r'([gim]+\b|\B)', String.Regex, '#pop'),
+            (r'(?=/)', Text, ('#pop', 'badregex')),
+            default('#pop')
+        ],
+        'badregex': [
+            (r'\n', Text, '#pop')
+        ],
+        'root': [
+            include('commentsandwhitespace'),
+            (r'\n', Text),
+            (r'[^\S\n]+', Text),
+            (r'\+\+|--|~|&&|\?|:|\|\||\\(?=\n)|\.\.'
+             r'(<<|>>>?|==?|!=?|[-<>+*%&|^/])=?', Operator, 'slashstartsregex'),
+            (r'[{(\[;,]', Punctuation, 'slashstartsregex'),
+            (r'[})\].]', Punctuation),
+            (r'[=+\-*/]', Operator),
+            (r'(for|in|while|do|break|return|continue|if|else|'
+             r'throw|try|catch'
+             r')\b', Keyword, 'slashstartsregex'),
+            (r'(var)\b', Keyword.Declaration, 'slashstartsregex'),
+            (r'(attr|def|fun)\b', Keyword.Reserved),
+            (r'(true|false)\b', Keyword.Constant),
+            (r'(eval|throw)\b', Name.Builtin),
+            (r'`\S+`', Name.Builtin),
+            (r'[$a-zA-Z_]\w*', Name.Other),
+            (r'[0-9][0-9]*\.[0-9]+([eE][0-9]+)?[fd]?', Number.Float),
+            (r'0x[0-9a-fA-F]+', Number.Hex),
+            (r'[0-9]+', Number.Integer),
+            (r'"', String.Double, 'dqstring'),
+            (r"'(\\\\|\\[^\\]|[^'\\])*'", String.Single),
+        ],
+        'dqstring': [
+            (r'\$\{[^"}]+?\}', String.Interpol),
+            (r'\$', String.Double),
+            (r'\\\\', String.Double),
+            (r'\\"', String.Double),
+            (r'[^\\"$]+', String.Double),
+            (r'"', String.Double, '#pop'),
+        ],
+    }
+
+
+class LSLLexer(RegexLexer):
+    """
+    For Second Life's Linden Scripting Language source code.
+    """
+
+    name = 'LSL'
+    aliases = ['lsl']
+    filenames = ['*.lsl']
+    mimetypes = ['text/x-lsl']
+    url = 'https://wiki.secondlife.com/wiki/Linden_Scripting_Language'
+    version_added = '2.0'
+
+    flags = re.MULTILINE
+
+    lsl_keywords = r'\b(?:do|else|for|if|jump|return|while)\b'
+    lsl_types = r'\b(?:float|integer|key|list|quaternion|rotation|string|vector)\b'
+    lsl_states = r'\b(?:(?:state)\s+\w+|default)\b'
+    lsl_events = r'\b(?:state_(?:entry|exit)|touch(?:_(?:start|end))?|(?:land_)?collision(?:_(?:start|end))?|timer|listen|(?:no_)?sensor|control|(?:not_)?at_(?:rot_)?target|money|email|run_time_permissions|changed|attach|dataserver|moving_(?:start|end)|link_message|(?:on|object)_rez|remote_data|http_re(?:sponse|quest)|path_update|transaction_result)\b'
+    lsl_functions_builtin = r'\b(?:ll(?:ReturnObjectsBy(?:ID|Owner)|Json(?:2List|[GS]etValue|ValueType)|Sin|Cos|Tan|Atan2|Sqrt|Pow|Abs|Fabs|Frand|Floor|Ceil|Round|Vec(?:Mag|Norm|Dist)|Rot(?:Between|2(?:Euler|Fwd|Left|Up))|(?:Euler|Axes)2Rot|Whisper|(?:Region|Owner)?Say|Shout|Listen(?:Control|Remove)?|Sensor(?:Repeat|Remove)?|Detected(?:Name|Key|Owner|Type|Pos|Vel|Grab|Rot|Group|LinkNumber)|Die|Ground|Wind|(?:[GS]et)(?:AnimationOverride|MemoryLimit|PrimMediaParams|ParcelMusicURL|Object(?:Desc|Name)|PhysicsMaterial|Status|Scale|Color|Alpha|Texture|Pos|Rot|Force|Torque)|ResetAnimationOverride|(?:Scale|Offset|Rotate)Texture|(?:Rot)?Target(?:Remove)?|(?:Stop)?MoveToTarget|Apply(?:Rotational)?Impulse|Set(?:KeyframedMotion|ContentType|RegionPos|(?:Angular)?Velocity|Buoyancy|HoverHeight|ForceAndTorque|TimerEvent|ScriptState|Damage|TextureAnim|Sound(?:Queueing|Radius)|Vehicle(?:Type|(?:Float|Vector|Rotation)Param)|(?:Touch|Sit)?Text|Camera(?:Eye|At)Offset|PrimitiveParams|ClickAction|Link(?:Alpha|Color|PrimitiveParams(?:Fast)?|Texture(?:Anim)?|Camera|Media)|RemoteScriptAccessPin|PayPrice|LocalRot)|ScaleByFactor|Get(?:(?:Max|Min)ScaleFactor|ClosestNavPoint|StaticPath|SimStats|Env|PrimitiveParams|Link(?:PrimitiveParams|Number(?:OfSides)?|Key|Name|Media)|HTTPHeader|FreeURLs|Object(?:Details|PermMask|PrimCount)|Parcel(?:MaxPrims|Details|Prim(?:Count|Owners))|Attached|(?:SPMax|Free|Used)Memory|Region(?:Name|TimeDilation|FPS|Corner|AgentCount)|Root(?:Position|Rotation)|UnixTime|(?:Parcel|Region)Flags|(?:Wall|GMT)clock|SimulatorHostname|BoundingBox|GeometricCenter|Creator|NumberOf(?:Prims|NotecardLines|Sides)|Animation(?:List)?|(?:Camera|Local)(?:Pos|Rot)|Vel|Accel|Omega|Time(?:stamp|OfDay)|(?:Object|CenterOf)?Mass|MassMKS|Energy|Owner|(?:Owner)?Key|SunDirection|Texture(?:Offset|Scale|Rot)|Inventory(?:Number|Name|Key|Type|Creator|PermMask)|Permissions(?:Key)?|StartParameter|List(?:Length|EntryType)|Date|Agent(?:Size|Info|Language|List)|LandOwnerAt|NotecardLine|Script(?:Name|State))|(?:Get|Reset|GetAndReset)Time|PlaySound(?:Slave)?|LoopSound(?:Master|Slave)?|(?:Trigger|Stop|Preload)Sound|(?:(?:Get|Delete)Sub|Insert)String|To(?:Upper|Lower)|Give(?:InventoryList|Money)|RezObject|(?:Stop)?LookAt|Sleep|CollisionFilter|(?:Take|Release)Controls|DetachFromAvatar|AttachToAvatar(?:Temp)?|InstantMessage|(?:GetNext)?Email|StopHover|MinEventDelay|RotLookAt|String(?:Length|Trim)|(?:Start|Stop)Animation|TargetOmega|RequestPermissions|(?:Create|Break)Link|BreakAllLinks|(?:Give|Remove)Inventory|Water|PassTouches|Request(?:Agent|Inventory)Data|TeleportAgent(?:Home|GlobalCoords)?|ModifyLand|CollisionSound|ResetScript|MessageLinked|PushObject|PassCollisions|AxisAngle2Rot|Rot2(?:Axis|Angle)|A(?:cos|sin)|AngleBetween|AllowInventoryDrop|SubStringIndex|List2(?:CSV|Integer|Json|Float|String|Key|Vector|Rot|List(?:Strided)?)|DeleteSubList|List(?:Statistics|Sort|Randomize|(?:Insert|Find|Replace)List)|EdgeOfWorld|AdjustSoundVolume|Key2Name|TriggerSoundLimited|EjectFromLand|(?:CSV|ParseString)2List|OverMyLand|SameGroup|UnSit|Ground(?:Slope|Normal|Contour)|GroundRepel|(?:Set|Remove)VehicleFlags|(?:AvatarOn)?(?:Link)?SitTarget|Script(?:Danger|Profiler)|Dialog|VolumeDetect|ResetOtherScript|RemoteLoadScriptPin|(?:Open|Close)RemoteDataChannel|SendRemoteData|RemoteDataReply|(?:Integer|String)ToBase64|XorBase64|Log(?:10)?|Base64To(?:String|Integer)|ParseStringKeepNulls|RezAtRoot|RequestSimulatorData|ForceMouselook|(?:Load|Release|(?:E|Une)scape)URL|ParcelMedia(?:CommandList|Query)|ModPow|MapDestination|(?:RemoveFrom|AddTo|Reset)Land(?:Pass|Ban)List|(?:Set|Clear)CameraParams|HTTP(?:Request|Response)|TextBox|DetectedTouch(?:UV|Face|Pos|(?:N|Bin)ormal|ST)|(?:MD5|SHA1|DumpList2)String|Request(?:Secure)?URL|Clear(?:Prim|Link)Media|(?:Link)?ParticleSystem|(?:Get|Request)(?:Username|DisplayName)|RegionSayTo|CastRay|GenerateKey|TransferLindenDollars|ManageEstateAccess|(?:Create|Delete)Character|ExecCharacterCmd|Evade|FleeFrom|NavigateTo|PatrolPoints|Pursue|UpdateCharacter|WanderWithin))\b'
+    lsl_constants_float = r'\b(?:DEG_TO_RAD|PI(?:_BY_TWO)?|RAD_TO_DEG|SQRT2|TWO_PI)\b'
+    lsl_constants_integer = r'\b(?:JSON_APPEND|STATUS_(?:PHYSICS|ROTATE_[XYZ]|PHANTOM|SANDBOX|BLOCK_GRAB(?:_OBJECT)?|(?:DIE|RETURN)_AT_EDGE|CAST_SHADOWS|OK|MALFORMED_PARAMS|TYPE_MISMATCH|BOUNDS_ERROR|NOT_(?:FOUND|SUPPORTED)|INTERNAL_ERROR|WHITELIST_FAILED)|AGENT(?:_(?:BY_(?:LEGACY_|USER)NAME|FLYING|ATTACHMENTS|SCRIPTED|MOUSELOOK|SITTING|ON_OBJECT|AWAY|WALKING|IN_AIR|TYPING|CROUCHING|BUSY|ALWAYS_RUN|AUTOPILOT|LIST_(?:PARCEL(?:_OWNER)?|REGION)))?|CAMERA_(?:PITCH|DISTANCE|BEHINDNESS_(?:ANGLE|LAG)|(?:FOCUS|POSITION)(?:_(?:THRESHOLD|LOCKED|LAG))?|FOCUS_OFFSET|ACTIVE)|ANIM_ON|LOOP|REVERSE|PING_PONG|SMOOTH|ROTATE|SCALE|ALL_SIDES|LINK_(?:ROOT|SET|ALL_(?:OTHERS|CHILDREN)|THIS)|ACTIVE|PASSIVE|SCRIPTED|CONTROL_(?:FWD|BACK|(?:ROT_)?(?:LEFT|RIGHT)|UP|DOWN|(?:ML_)?LBUTTON)|PERMISSION_(?:RETURN_OBJECTS|DEBIT|OVERRIDE_ANIMATIONS|SILENT_ESTATE_MANAGEMENT|TAKE_CONTROLS|TRIGGER_ANIMATION|ATTACH|CHANGE_LINKS|(?:CONTROL|TRACK)_CAMERA|TELEPORT)|INVENTORY_(?:TEXTURE|SOUND|OBJECT|SCRIPT|LANDMARK|CLOTHING|NOTECARD|BODYPART|ANIMATION|GESTURE|ALL|NONE)|CHANGED_(?:INVENTORY|COLOR|SHAPE|SCALE|TEXTURE|LINK|ALLOWED_DROP|OWNER|REGION(?:_START)?|TELEPORT|MEDIA)|OBJECT_(?:(?:PHYSICS|SERVER|STREAMING)_COST|UNKNOWN_DETAIL|CHARACTER_TIME|PHANTOM|PHYSICS|TEMP_ON_REZ|NAME|DESC|POS|PRIM_EQUIVALENCE|RETURN_(?:PARCEL(?:_OWNER)?|REGION)|ROO?T|VELOCITY|OWNER|GROUP|CREATOR|ATTACHED_POINT|RENDER_WEIGHT|PATHFINDING_TYPE|(?:RUNNING|TOTAL)_SCRIPT_COUNT|SCRIPT_(?:MEMORY|TIME))|TYPE_(?:INTEGER|FLOAT|STRING|KEY|VECTOR|ROTATION|INVALID)|(?:DEBUG|PUBLIC)_CHANNEL|ATTACH_(?:AVATAR_CENTER|CHEST|HEAD|BACK|PELVIS|MOUTH|CHIN|NECK|NOSE|BELLY|[LR](?:SHOULDER|HAND|FOOT|EAR|EYE|[UL](?:ARM|LEG)|HIP)|(?:LEFT|RIGHT)_PEC|HUD_(?:CENTER_[12]|TOP_(?:RIGHT|CENTER|LEFT)|BOTTOM(?:_(?:RIGHT|LEFT))?))|LAND_(?:LEVEL|RAISE|LOWER|SMOOTH|NOISE|REVERT)|DATA_(?:ONLINE|NAME|BORN|SIM_(?:POS|STATUS|RATING)|PAYINFO)|PAYMENT_INFO_(?:ON_FILE|USED)|REMOTE_DATA_(?:CHANNEL|REQUEST|REPLY)|PSYS_(?:PART_(?:BF_(?:ZERO|ONE(?:_MINUS_(?:DEST_COLOR|SOURCE_(ALPHA|COLOR)))?|DEST_COLOR|SOURCE_(ALPHA|COLOR))|BLEND_FUNC_(DEST|SOURCE)|FLAGS|(?:START|END)_(?:COLOR|ALPHA|SCALE|GLOW)|MAX_AGE|(?:RIBBON|WIND|INTERP_(?:COLOR|SCALE)|BOUNCE|FOLLOW_(?:SRC|VELOCITY)|TARGET_(?:POS|LINEAR)|EMISSIVE)_MASK)|SRC_(?:MAX_AGE|PATTERN|ANGLE_(?:BEGIN|END)|BURST_(?:RATE|PART_COUNT|RADIUS|SPEED_(?:MIN|MAX))|ACCEL|TEXTURE|TARGET_KEY|OMEGA|PATTERN_(?:DROP|EXPLODE|ANGLE(?:_CONE(?:_EMPTY)?)?)))|VEHICLE_(?:REFERENCE_FRAME|TYPE_(?:NONE|SLED|CAR|BOAT|AIRPLANE|BALLOON)|(?:LINEAR|ANGULAR)_(?:FRICTION_TIMESCALE|MOTOR_DIRECTION)|LINEAR_MOTOR_OFFSET|HOVER_(?:HEIGHT|EFFICIENCY|TIMESCALE)|BUOYANCY|(?:LINEAR|ANGULAR)_(?:DEFLECTION_(?:EFFICIENCY|TIMESCALE)|MOTOR_(?:DECAY_)?TIMESCALE)|VERTICAL_ATTRACTION_(?:EFFICIENCY|TIMESCALE)|BANKING_(?:EFFICIENCY|MIX|TIMESCALE)|FLAG_(?:NO_DEFLECTION_UP|LIMIT_(?:ROLL_ONLY|MOTOR_UP)|HOVER_(?:(?:WATER|TERRAIN|UP)_ONLY|GLOBAL_HEIGHT)|MOUSELOOK_(?:STEER|BANK)|CAMERA_DECOUPLED))|PRIM_(?:TYPE(?:_(?:BOX|CYLINDER|PRISM|SPHERE|TORUS|TUBE|RING|SCULPT))?|HOLE_(?:DEFAULT|CIRCLE|SQUARE|TRIANGLE)|MATERIAL(?:_(?:STONE|METAL|GLASS|WOOD|FLESH|PLASTIC|RUBBER))?|SHINY_(?:NONE|LOW|MEDIUM|HIGH)|BUMP_(?:NONE|BRIGHT|DARK|WOOD|BARK|BRICKS|CHECKER|CONCRETE|TILE|STONE|DISKS|GRAVEL|BLOBS|SIDING|LARGETILE|STUCCO|SUCTION|WEAVE)|TEXGEN_(?:DEFAULT|PLANAR)|SCULPT_(?:TYPE_(?:SPHERE|TORUS|PLANE|CYLINDER|MASK)|FLAG_(?:MIRROR|INVERT))|PHYSICS(?:_(?:SHAPE_(?:CONVEX|NONE|PRIM|TYPE)))?|(?:POS|ROT)_LOCAL|SLICE|TEXT|FLEXIBLE|POINT_LIGHT|TEMP_ON_REZ|PHANTOM|POSITION|SIZE|ROTATION|TEXTURE|NAME|OMEGA|DESC|LINK_TARGET|COLOR|BUMP_SHINY|FULLBRIGHT|TEXGEN|GLOW|MEDIA_(?:ALT_IMAGE_ENABLE|CONTROLS|(?:CURRENT|HOME)_URL|AUTO_(?:LOOP|PLAY|SCALE|ZOOM)|FIRST_CLICK_INTERACT|(?:WIDTH|HEIGHT)_PIXELS|WHITELIST(?:_ENABLE)?|PERMS_(?:INTERACT|CONTROL)|PARAM_MAX|CONTROLS_(?:STANDARD|MINI)|PERM_(?:NONE|OWNER|GROUP|ANYONE)|MAX_(?:URL_LENGTH|WHITELIST_(?:SIZE|COUNT)|(?:WIDTH|HEIGHT)_PIXELS)))|MASK_(?:BASE|OWNER|GROUP|EVERYONE|NEXT)|PERM_(?:TRANSFER|MODIFY|COPY|MOVE|ALL)|PARCEL_(?:MEDIA_COMMAND_(?:STOP|PAUSE|PLAY|LOOP|TEXTURE|URL|TIME|AGENT|UNLOAD|AUTO_ALIGN|TYPE|SIZE|DESC|LOOP_SET)|FLAG_(?:ALLOW_(?:FLY|(?:GROUP_)?SCRIPTS|LANDMARK|TERRAFORM|DAMAGE|CREATE_(?:GROUP_)?OBJECTS)|USE_(?:ACCESS_(?:GROUP|LIST)|BAN_LIST|LAND_PASS_LIST)|LOCAL_SOUND_ONLY|RESTRICT_PUSHOBJECT|ALLOW_(?:GROUP|ALL)_OBJECT_ENTRY)|COUNT_(?:TOTAL|OWNER|GROUP|OTHER|SELECTED|TEMP)|DETAILS_(?:NAME|DESC|OWNER|GROUP|AREA|ID|SEE_AVATARS))|LIST_STAT_(?:MAX|MIN|MEAN|MEDIAN|STD_DEV|SUM(?:_SQUARES)?|NUM_COUNT|GEOMETRIC_MEAN|RANGE)|PAY_(?:HIDE|DEFAULT)|REGION_FLAG_(?:ALLOW_DAMAGE|FIXED_SUN|BLOCK_TERRAFORM|SANDBOX|DISABLE_(?:COLLISIONS|PHYSICS)|BLOCK_FLY|ALLOW_DIRECT_TELEPORT|RESTRICT_PUSHOBJECT)|HTTP_(?:METHOD|MIMETYPE|BODY_(?:MAXLENGTH|TRUNCATED)|CUSTOM_HEADER|PRAGMA_NO_CACHE|VERBOSE_THROTTLE|VERIFY_CERT)|STRING_(?:TRIM(?:_(?:HEAD|TAIL))?)|CLICK_ACTION_(?:NONE|TOUCH|SIT|BUY|PAY|OPEN(?:_MEDIA)?|PLAY|ZOOM)|TOUCH_INVALID_FACE|PROFILE_(?:NONE|SCRIPT_MEMORY)|RC_(?:DATA_FLAGS|DETECT_PHANTOM|GET_(?:LINK_NUM|NORMAL|ROOT_KEY)|MAX_HITS|REJECT_(?:TYPES|AGENTS|(?:NON)?PHYSICAL|LAND))|RCERR_(?:CAST_TIME_EXCEEDED|SIM_PERF_LOW|UNKNOWN)|ESTATE_ACCESS_(?:ALLOWED_(?:AGENT|GROUP)_(?:ADD|REMOVE)|BANNED_AGENT_(?:ADD|REMOVE))|DENSITY|FRICTION|RESTITUTION|GRAVITY_MULTIPLIER|KFM_(?:COMMAND|CMD_(?:PLAY|STOP|PAUSE|SET_MODE)|MODE|FORWARD|LOOP|PING_PONG|REVERSE|DATA|ROTATION|TRANSLATION)|ERR_(?:GENERIC|PARCEL_PERMISSIONS|MALFORMED_PARAMS|RUNTIME_PERMISSIONS|THROTTLED)|CHARACTER_(?:CMD_(?:(?:SMOOTH_)?STOP|JUMP)|DESIRED_(?:TURN_)?SPEED|RADIUS|STAY_WITHIN_PARCEL|LENGTH|ORIENTATION|ACCOUNT_FOR_SKIPPED_FRAMES|AVOIDANCE_MODE|TYPE(?:_(?:[A-D]|NONE))?|MAX_(?:DECEL|TURN_RADIUS|(?:ACCEL|SPEED)))|PURSUIT_(?:OFFSET|FUZZ_FACTOR|GOAL_TOLERANCE|INTERCEPT)|REQUIRE_LINE_OF_SIGHT|FORCE_DIRECT_PATH|VERTICAL|HORIZONTAL|AVOID_(?:CHARACTERS|DYNAMIC_OBSTACLES|NONE)|PU_(?:EVADE_(?:HIDDEN|SPOTTED)|FAILURE_(?:DYNAMIC_PATHFINDING_DISABLED|INVALID_(?:GOAL|START)|NO_(?:NAVMESH|VALID_DESTINATION)|OTHER|TARGET_GONE|(?:PARCEL_)?UNREACHABLE)|(?:GOAL|SLOWDOWN_DISTANCE)_REACHED)|TRAVERSAL_TYPE(?:_(?:FAST|NONE|SLOW))?|CONTENT_TYPE_(?:ATOM|FORM|HTML|JSON|LLSD|RSS|TEXT|XHTML|XML)|GCNP_(?:RADIUS|STATIC)|(?:PATROL|WANDER)_PAUSE_AT_WAYPOINTS|OPT_(?:AVATAR|CHARACTER|EXCLUSION_VOLUME|LEGACY_LINKSET|MATERIAL_VOLUME|OTHER|STATIC_OBSTACLE|WALKABLE)|SIM_STAT_PCT_CHARS_STEPPED)\b'
+    lsl_constants_integer_boolean = r'\b(?:FALSE|TRUE)\b'
+    lsl_constants_rotation = r'\b(?:ZERO_ROTATION)\b'
+    lsl_constants_string = r'\b(?:EOF|JSON_(?:ARRAY|DELETE|FALSE|INVALID|NULL|NUMBER|OBJECT|STRING|TRUE)|NULL_KEY|TEXTURE_(?:BLANK|DEFAULT|MEDIA|PLYWOOD|TRANSPARENT)|URL_REQUEST_(?:GRANTED|DENIED))\b'
+    lsl_constants_vector = r'\b(?:TOUCH_INVALID_(?:TEXCOORD|VECTOR)|ZERO_VECTOR)\b'
+    lsl_invalid_broken = r'\b(?:LAND_(?:LARGE|MEDIUM|SMALL)_BRUSH)\b'
+    lsl_invalid_deprecated = r'\b(?:ATTACH_[LR]PEC|DATA_RATING|OBJECT_ATTACHMENT_(?:GEOMETRY_BYTES|SURFACE_AREA)|PRIM_(?:CAST_SHADOWS|MATERIAL_LIGHT|TYPE_LEGACY)|PSYS_SRC_(?:INNER|OUTER)ANGLE|VEHICLE_FLAG_NO_FLY_UP|ll(?:Cloud|Make(?:Explosion|Fountain|Smoke|Fire)|RemoteDataSetRegion|Sound(?:Preload)?|XorBase64Strings(?:Correct)?))\b'
+    lsl_invalid_illegal = r'\b(?:event)\b'
+    lsl_invalid_unimplemented = r'\b(?:CHARACTER_(?:MAX_ANGULAR_(?:ACCEL|SPEED)|TURN_SPEED_MULTIPLIER)|PERMISSION_(?:CHANGE_(?:JOINTS|PERMISSIONS)|RELEASE_OWNERSHIP|REMAP_CONTROLS)|PRIM_PHYSICS_MATERIAL|PSYS_SRC_OBJ_REL_MASK|ll(?:CollisionSprite|(?:Stop)?PointAt|(?:(?:Refresh|Set)Prim)URL|(?:Take|Release)Camera|RemoteLoadScript))\b'
+    lsl_reserved_godmode = r'\b(?:ll(?:GodLikeRezObject|Set(?:Inventory|Object)PermMask))\b'
+    lsl_reserved_log = r'\b(?:print)\b'
+    lsl_operators = r'\+\+|\-\-|<<|>>|&&?|\|\|?|\^|~|[!%<>=*+\-/]=?'
+
+    tokens = {
+        'root':
+        [
+            (r'//.*?\n',                          Comment.Single),
+            (r'/\*',                              Comment.Multiline, 'comment'),
+            (r'"',                                String.Double, 'string'),
+            (lsl_keywords,                        Keyword),
+            (lsl_types,                           Keyword.Type),
+            (lsl_states,                          Name.Class),
+            (lsl_events,                          Name.Builtin),
+            (lsl_functions_builtin,               Name.Function),
+            (lsl_constants_float,                 Keyword.Constant),
+            (lsl_constants_integer,               Keyword.Constant),
+            (lsl_constants_integer_boolean,       Keyword.Constant),
+            (lsl_constants_rotation,              Keyword.Constant),
+            (lsl_constants_string,                Keyword.Constant),
+            (lsl_constants_vector,                Keyword.Constant),
+            (lsl_invalid_broken,                  Error),
+            (lsl_invalid_deprecated,              Error),
+            (lsl_invalid_illegal,                 Error),
+            (lsl_invalid_unimplemented,           Error),
+            (lsl_reserved_godmode,                Keyword.Reserved),
+            (lsl_reserved_log,                    Keyword.Reserved),
+            (r'\b([a-zA-Z_]\w*)\b',     Name.Variable),
+            (r'(\d+\.\d*|\.\d+|\d+)[eE][+-]?\d*', Number.Float),
+            (r'(\d+\.\d*|\.\d+)',                 Number.Float),
+            (r'0[xX][0-9a-fA-F]+',                Number.Hex),
+            (r'\d+',                              Number.Integer),
+            (lsl_operators,                       Operator),
+            (r':=?',                              Error),
+            (r'[,;{}()\[\]]',                     Punctuation),
+            (r'\n+',                              Whitespace),
+            (r'\s+',                              Whitespace)
+        ],
+        'comment':
+        [
+            (r'[^*/]+',                           Comment.Multiline),
+            (r'/\*',                              Comment.Multiline, '#push'),
+            (r'\*/',                              Comment.Multiline, '#pop'),
+            (r'[*/]',                             Comment.Multiline)
+        ],
+        'string':
+        [
+            (r'\\([nt"\\])',                      String.Escape),
+            (r'"',                                String.Double, '#pop'),
+            (r'\\.',                              Error),
+            (r'[^"\\]+',                          String.Double),
+        ]
+    }
+
+
+class AppleScriptLexer(RegexLexer):
+    """
+    For AppleScript source code,
+    including `AppleScript Studio
+    `_.
+    Contributed by Andreas Amann .
+    """
+
+    name = 'AppleScript'
+    url = 'https://developer.apple.com/library/archive/documentation/AppleScript/Conceptual/AppleScriptLangGuide/introduction/ASLR_intro.html'
+    aliases = ['applescript']
+    filenames = ['*.applescript']
+    version_added = '1.0'
+
+    flags = re.MULTILINE | re.DOTALL
+
+    Identifiers = r'[a-zA-Z]\w*'
+
+    # XXX: use words() for all of these
+    Literals = ('AppleScript', 'current application', 'false', 'linefeed',
+                'missing value', 'pi', 'quote', 'result', 'return', 'space',
+                'tab', 'text item delimiters', 'true', 'version')
+    Classes = ('alias ', 'application ', 'boolean ', 'class ', 'constant ',
+               'date ', 'file ', 'integer ', 'list ', 'number ', 'POSIX file ',
+               'real ', 'record ', 'reference ', 'RGB color ', 'script ',
+               'text ', 'unit types', '(?:Unicode )?text', 'string')
+    BuiltIn = ('attachment', 'attribute run', 'character', 'day', 'month',
+               'paragraph', 'word', 'year')
+    HandlerParams = ('about', 'above', 'against', 'apart from', 'around',
+                     'aside from', 'at', 'below', 'beneath', 'beside',
+                     'between', 'for', 'given', 'instead of', 'on', 'onto',
+                     'out of', 'over', 'since')
+    Commands = ('ASCII (character|number)', 'activate', 'beep', 'choose URL',
+                'choose application', 'choose color', 'choose file( name)?',
+                'choose folder', 'choose from list',
+                'choose remote application', 'clipboard info',
+                'close( access)?', 'copy', 'count', 'current date', 'delay',
+                'delete', 'display (alert|dialog)', 'do shell script',
+                'duplicate', 'exists', 'get eof', 'get volume settings',
+                'info for', 'launch', 'list (disks|folder)', 'load script',
+                'log', 'make', 'mount volume', 'new', 'offset',
+                'open( (for access|location))?', 'path to', 'print', 'quit',
+                'random number', 'read', 'round', 'run( script)?',
+                'say', 'scripting components',
+                'set (eof|the clipboard to|volume)', 'store script',
+                'summarize', 'system attribute', 'system info',
+                'the clipboard', 'time to GMT', 'write', 'quoted form')
+    References = ('(in )?back of', '(in )?front of', '[0-9]+(st|nd|rd|th)',
+                  'first', 'second', 'third', 'fourth', 'fifth', 'sixth',
+                  'seventh', 'eighth', 'ninth', 'tenth', 'after', 'back',
+                  'before', 'behind', 'every', 'front', 'index', 'last',
+                  'middle', 'some', 'that', 'through', 'thru', 'where', 'whose')
+    Operators = ("and", "or", "is equal", "equals", "(is )?equal to", "is not",
+                 "isn't", "isn't equal( to)?", "is not equal( to)?",
+                 "doesn't equal", "does not equal", "(is )?greater than",
+                 "comes after", "is not less than or equal( to)?",
+                 "isn't less than or equal( to)?", "(is )?less than",
+                 "comes before", "is not greater than or equal( to)?",
+                 "isn't greater than or equal( to)?",
+                 "(is  )?greater than or equal( to)?", "is not less than",
+                 "isn't less than", "does not come before",
+                 "doesn't come before", "(is )?less than or equal( to)?",
+                 "is not greater than", "isn't greater than",
+                 "does not come after", "doesn't come after", "starts? with",
+                 "begins? with", "ends? with", "contains?", "does not contain",
+                 "doesn't contain", "is in", "is contained by", "is not in",
+                 "is not contained by", "isn't contained by", "div", "mod",
+                 "not", "(a  )?(ref( to)?|reference to)", "is", "does")
+    Control = ('considering', 'else', 'error', 'exit', 'from', 'if',
+               'ignoring', 'in', 'repeat', 'tell', 'then', 'times', 'to',
+               'try', 'until', 'using terms from', 'while', 'whith',
+               'with timeout( of)?', 'with transaction', 'by', 'continue',
+               'end', 'its?', 'me', 'my', 'return', 'of', 'as')
+    Declarations = ('global', 'local', 'prop(erty)?', 'set', 'get')
+    Reserved = ('but', 'put', 'returning', 'the')
+    StudioClasses = ('action cell', 'alert reply', 'application', 'box',
+                     'browser( cell)?', 'bundle', 'button( cell)?', 'cell',
+                     'clip view', 'color well', 'color-panel',
+                     'combo box( item)?', 'control',
+                     'data( (cell|column|item|row|source))?', 'default entry',
+                     'dialog reply', 'document', 'drag info', 'drawer',
+                     'event', 'font(-panel)?', 'formatter',
+                     'image( (cell|view))?', 'matrix', 'menu( item)?', 'item',
+                     'movie( view)?', 'open-panel', 'outline view', 'panel',
+                     'pasteboard', 'plugin', 'popup button',
+                     'progress indicator', 'responder', 'save-panel',
+                     'scroll view', 'secure text field( cell)?', 'slider',
+                     'sound', 'split view', 'stepper', 'tab view( item)?',
+                     'table( (column|header cell|header view|view))',
+                     'text( (field( cell)?|view))?', 'toolbar( item)?',
+                     'user-defaults', 'view', 'window')
+    StudioEvents = ('accept outline drop', 'accept table drop', 'action',
+                    'activated', 'alert ended', 'awake from nib', 'became key',
+                    'became main', 'begin editing', 'bounds changed',
+                    'cell value', 'cell value changed', 'change cell value',
+                    'change item value', 'changed', 'child of item',
+                    'choose menu item', 'clicked', 'clicked toolbar item',
+                    'closed', 'column clicked', 'column moved',
+                    'column resized', 'conclude drop', 'data representation',
+                    'deminiaturized', 'dialog ended', 'document nib name',
+                    'double clicked', 'drag( (entered|exited|updated))?',
+                    'drop', 'end editing', 'exposed', 'idle', 'item expandable',
+                    'item value', 'item value changed', 'items changed',
+                    'keyboard down', 'keyboard up', 'launched',
+                    'load data representation', 'miniaturized', 'mouse down',
+                    'mouse dragged', 'mouse entered', 'mouse exited',
+                    'mouse moved', 'mouse up', 'moved',
+                    'number of browser rows', 'number of items',
+                    'number of rows', 'open untitled', 'opened', 'panel ended',
+                    'parameters updated', 'plugin loaded', 'prepare drop',
+                    'prepare outline drag', 'prepare outline drop',
+                    'prepare table drag', 'prepare table drop',
+                    'read from file', 'resigned active', 'resigned key',
+                    'resigned main', 'resized( sub views)?',
+                    'right mouse down', 'right mouse dragged',
+                    'right mouse up', 'rows changed', 'scroll wheel',
+                    'selected tab view item', 'selection changed',
+                    'selection changing', 'should begin editing',
+                    'should close', 'should collapse item',
+                    'should end editing', 'should expand item',
+                    'should open( untitled)?',
+                    'should quit( after last window closed)?',
+                    'should select column', 'should select item',
+                    'should select row', 'should select tab view item',
+                    'should selection change', 'should zoom', 'shown',
+                    'update menu item', 'update parameters',
+                    'update toolbar item', 'was hidden', 'was miniaturized',
+                    'will become active', 'will close', 'will dismiss',
+                    'will display browser cell', 'will display cell',
+                    'will display item cell', 'will display outline cell',
+                    'will finish launching', 'will hide', 'will miniaturize',
+                    'will move', 'will open', 'will pop up', 'will quit',
+                    'will resign active', 'will resize( sub views)?',
+                    'will select tab view item', 'will show', 'will zoom',
+                    'write to file', 'zoomed')
+    StudioCommands = ('animate', 'append', 'call method', 'center',
+                      'close drawer', 'close panel', 'display',
+                      'display alert', 'display dialog', 'display panel', 'go',
+                      'hide', 'highlight', 'increment', 'item for',
+                      'load image', 'load movie', 'load nib', 'load panel',
+                      'load sound', 'localized string', 'lock focus', 'log',
+                      'open drawer', 'path for', 'pause', 'perform action',
+                      'play', 'register', 'resume', 'scroll', 'select( all)?',
+                      'show', 'size to fit', 'start', 'step back',
+                      'step forward', 'stop', 'synchronize', 'unlock focus',
+                      'update')
+    StudioProperties = ('accepts arrow key', 'action method', 'active',
+                        'alignment', 'allowed identifiers',
+                        'allows branch selection', 'allows column reordering',
+                        'allows column resizing', 'allows column selection',
+                        'allows customization',
+                        'allows editing text attributes',
+                        'allows empty selection', 'allows mixed state',
+                        'allows multiple selection', 'allows reordering',
+                        'allows undo', 'alpha( value)?', 'alternate image',
+                        'alternate increment value', 'alternate title',
+                        'animation delay', 'associated file name',
+                        'associated object', 'auto completes', 'auto display',
+                        'auto enables items', 'auto repeat',
+                        'auto resizes( outline column)?',
+                        'auto save expanded items', 'auto save name',
+                        'auto save table columns', 'auto saves configuration',
+                        'auto scroll', 'auto sizes all columns to fit',
+                        'auto sizes cells', 'background color', 'bezel state',
+                        'bezel style', 'bezeled', 'border rect', 'border type',
+                        'bordered', 'bounds( rotation)?', 'box type',
+                        'button returned', 'button type',
+                        'can choose directories', 'can choose files',
+                        'can draw', 'can hide',
+                        'cell( (background color|size|type))?', 'characters',
+                        'class', 'click count', 'clicked( data)? column',
+                        'clicked data item', 'clicked( data)? row',
+                        'closeable', 'collating', 'color( (mode|panel))',
+                        'command key down', 'configuration',
+                        'content(s| (size|view( margins)?))?', 'context',
+                        'continuous', 'control key down', 'control size',
+                        'control tint', 'control view',
+                        'controller visible', 'coordinate system',
+                        'copies( on scroll)?', 'corner view', 'current cell',
+                        'current column', 'current( field)?  editor',
+                        'current( menu)? item', 'current row',
+                        'current tab view item', 'data source',
+                        'default identifiers', 'delta (x|y|z)',
+                        'destination window', 'directory', 'display mode',
+                        'displayed cell', 'document( (edited|rect|view))?',
+                        'double value', 'dragged column', 'dragged distance',
+                        'dragged items', 'draws( cell)? background',
+                        'draws grid', 'dynamically scrolls', 'echos bullets',
+                        'edge', 'editable', 'edited( data)? column',
+                        'edited data item', 'edited( data)? row', 'enabled',
+                        'enclosing scroll view', 'ending page',
+                        'error handling', 'event number', 'event type',
+                        'excluded from windows menu', 'executable path',
+                        'expanded', 'fax number', 'field editor', 'file kind',
+                        'file name', 'file type', 'first responder',
+                        'first visible column', 'flipped', 'floating',
+                        'font( panel)?', 'formatter', 'frameworks path',
+                        'frontmost', 'gave up', 'grid color', 'has data items',
+                        'has horizontal ruler', 'has horizontal scroller',
+                        'has parent data item', 'has resize indicator',
+                        'has shadow', 'has sub menu', 'has vertical ruler',
+                        'has vertical scroller', 'header cell', 'header view',
+                        'hidden', 'hides when deactivated', 'highlights by',
+                        'horizontal line scroll', 'horizontal page scroll',
+                        'horizontal ruler view', 'horizontally resizable',
+                        'icon image', 'id', 'identifier',
+                        'ignores multiple clicks',
+                        'image( (alignment|dims when disabled|frame style|scaling))?',
+                        'imports graphics', 'increment value',
+                        'indentation per level', 'indeterminate', 'index',
+                        'integer value', 'intercell spacing', 'item height',
+                        'key( (code|equivalent( modifier)?|window))?',
+                        'knob thickness', 'label', 'last( visible)? column',
+                        'leading offset', 'leaf', 'level', 'line scroll',
+                        'loaded', 'localized sort', 'location', 'loop mode',
+                        'main( (bunde|menu|window))?', 'marker follows cell',
+                        'matrix mode', 'maximum( content)? size',
+                        'maximum visible columns',
+                        'menu( form representation)?', 'miniaturizable',
+                        'miniaturized', 'minimized image', 'minimized title',
+                        'minimum column width', 'minimum( content)? size',
+                        'modal', 'modified', 'mouse down state',
+                        'movie( (controller|file|rect))?', 'muted', 'name',
+                        'needs display', 'next state', 'next text',
+                        'number of tick marks', 'only tick mark values',
+                        'opaque', 'open panel', 'option key down',
+                        'outline table column', 'page scroll', 'pages across',
+                        'pages down', 'palette label', 'pane splitter',
+                        'parent data item', 'parent window', 'pasteboard',
+                        'path( (names|separator))?', 'playing',
+                        'plays every frame', 'plays selection only', 'position',
+                        'preferred edge', 'preferred type', 'pressure',
+                        'previous text', 'prompt', 'properties',
+                        'prototype cell', 'pulls down', 'rate',
+                        'released when closed', 'repeated',
+                        'requested print time', 'required file type',
+                        'resizable', 'resized column', 'resource path',
+                        'returns records', 'reuses columns', 'rich text',
+                        'roll over', 'row height', 'rulers visible',
+                        'save panel', 'scripts path', 'scrollable',
+                        'selectable( identifiers)?', 'selected cell',
+                        'selected( data)? columns?', 'selected data items?',
+                        'selected( data)? rows?', 'selected item identifier',
+                        'selection by rect', 'send action on arrow key',
+                        'sends action when done editing', 'separates columns',
+                        'separator item', 'sequence number', 'services menu',
+                        'shared frameworks path', 'shared support path',
+                        'sheet', 'shift key down', 'shows alpha',
+                        'shows state by', 'size( mode)?',
+                        'smart insert delete enabled', 'sort case sensitivity',
+                        'sort column', 'sort order', 'sort type',
+                        'sorted( data rows)?', 'sound', 'source( mask)?',
+                        'spell checking enabled', 'starting page', 'state',
+                        'string value', 'sub menu', 'super menu', 'super view',
+                        'tab key traverses cells', 'tab state', 'tab type',
+                        'tab view', 'table view', 'tag', 'target( printer)?',
+                        'text color', 'text container insert',
+                        'text container origin', 'text returned',
+                        'tick mark position', 'time stamp',
+                        'title(d| (cell|font|height|position|rect))?',
+                        'tool tip', 'toolbar', 'trailing offset', 'transparent',
+                        'treat packages as directories', 'truncated labels',
+                        'types', 'unmodified characters', 'update views',
+                        'use sort indicator', 'user defaults',
+                        'uses data source', 'uses ruler',
+                        'uses threaded animation',
+                        'uses title from previous column', 'value wraps',
+                        'version',
+                        'vertical( (line scroll|page scroll|ruler view))?',
+                        'vertically resizable', 'view',
+                        'visible( document rect)?', 'volume', 'width', 'window',
+                        'windows menu', 'wraps', 'zoomable', 'zoomed')
+
+    tokens = {
+        'root': [
+            (r'\s+', Text),
+            (r'¬\n', String.Escape),
+            (r"'s\s+", Text),  # This is a possessive, consider moving
+            (r'(--|#).*?$', Comment),
+            (r'\(\*', Comment.Multiline, 'comment'),
+            (r'[(){}!,.:]', Punctuation),
+            (r'(«)([^»]+)(»)',
+             bygroups(Text, Name.Builtin, Text)),
+            (r'\b((?:considering|ignoring)\s*)'
+             r'(application responses|case|diacriticals|hyphens|'
+             r'numeric strings|punctuation|white space)',
+             bygroups(Keyword, Name.Builtin)),
+            (r'(-|\*|\+|&|≠|>=?|<=?|=|≥|≤|/|÷|\^)', Operator),
+            (r"\b({})\b".format('|'.join(Operators)), Operator.Word),
+            (r'^(\s*(?:on|end)\s+)'
+             r'({})'.format('|'.join(StudioEvents[::-1])),
+             bygroups(Keyword, Name.Function)),
+            (r'^(\s*)(in|on|script|to)(\s+)', bygroups(Text, Keyword, Text)),
+            (r'\b(as )({})\b'.format('|'.join(Classes)),
+             bygroups(Keyword, Name.Class)),
+            (r'\b({})\b'.format('|'.join(Literals)), Name.Constant),
+            (r'\b({})\b'.format('|'.join(Commands)), Name.Builtin),
+            (r'\b({})\b'.format('|'.join(Control)), Keyword),
+            (r'\b({})\b'.format('|'.join(Declarations)), Keyword),
+            (r'\b({})\b'.format('|'.join(Reserved)), Name.Builtin),
+            (r'\b({})s?\b'.format('|'.join(BuiltIn)), Name.Builtin),
+            (r'\b({})\b'.format('|'.join(HandlerParams)), Name.Builtin),
+            (r'\b({})\b'.format('|'.join(StudioProperties)), Name.Attribute),
+            (r'\b({})s?\b'.format('|'.join(StudioClasses)), Name.Builtin),
+            (r'\b({})\b'.format('|'.join(StudioCommands)), Name.Builtin),
+            (r'\b({})\b'.format('|'.join(References)), Name.Builtin),
+            (r'"(\\\\|\\[^\\]|[^"\\])*"', String.Double),
+            (rf'\b({Identifiers})\b', Name.Variable),
+            (r'[-+]?(\d+\.\d*|\d*\.\d+)(E[-+][0-9]+)?', Number.Float),
+            (r'[-+]?\d+', Number.Integer),
+        ],
+        'comment': [
+            (r'\(\*', Comment.Multiline, '#push'),
+            (r'\*\)', Comment.Multiline, '#pop'),
+            ('[^*(]+', Comment.Multiline),
+            ('[*(]', Comment.Multiline),
+        ],
+    }
+
+
+class RexxLexer(RegexLexer):
+    """
+    Rexx is a scripting language available for
+    a wide range of different platforms with its roots found on mainframe
+    systems. It is popular for I/O- and data based tasks and can act as glue
+    language to bind different applications together.
+    """
+    name = 'Rexx'
+    url = 'http://www.rexxinfo.org/'
+    aliases = ['rexx', 'arexx']
+    filenames = ['*.rexx', '*.rex', '*.rx', '*.arexx']
+    mimetypes = ['text/x-rexx']
+    version_added = '2.0'
+    flags = re.IGNORECASE
+
+    tokens = {
+        'root': [
+            (r'\s+', Whitespace),
+            (r'/\*', Comment.Multiline, 'comment'),
+            (r'"', String, 'string_double'),
+            (r"'", String, 'string_single'),
+            (r'[0-9]+(\.[0-9]+)?(e[+-]?[0-9])?', Number),
+            (r'([a-z_]\w*)(\s*)(:)(\s*)(procedure)\b',
+             bygroups(Name.Function, Whitespace, Operator, Whitespace,
+                      Keyword.Declaration)),
+            (r'([a-z_]\w*)(\s*)(:)',
+             bygroups(Name.Label, Whitespace, Operator)),
+            include('function'),
+            include('keyword'),
+            include('operator'),
+            (r'[a-z_]\w*', Text),
+        ],
+        'function': [
+            (words((
+                'abbrev', 'abs', 'address', 'arg', 'b2x', 'bitand', 'bitor', 'bitxor',
+                'c2d', 'c2x', 'center', 'charin', 'charout', 'chars', 'compare',
+                'condition', 'copies', 'd2c', 'd2x', 'datatype', 'date', 'delstr',
+                'delword', 'digits', 'errortext', 'form', 'format', 'fuzz', 'insert',
+                'lastpos', 'left', 'length', 'linein', 'lineout', 'lines', 'max',
+                'min', 'overlay', 'pos', 'queued', 'random', 'reverse', 'right', 'sign',
+                'sourceline', 'space', 'stream', 'strip', 'substr', 'subword', 'symbol',
+                'time', 'trace', 'translate', 'trunc', 'value', 'verify', 'word',
+                'wordindex', 'wordlength', 'wordpos', 'words', 'x2b', 'x2c', 'x2d',
+                'xrange'), suffix=r'(\s*)(\()'),
+             bygroups(Name.Builtin, Whitespace, Operator)),
+        ],
+        'keyword': [
+            (r'(address|arg|by|call|do|drop|else|end|exit|for|forever|if|'
+             r'interpret|iterate|leave|nop|numeric|off|on|options|parse|'
+             r'pull|push|queue|return|say|select|signal|to|then|trace|until|'
+             r'while)\b', Keyword.Reserved),
+        ],
+        'operator': [
+            (r'(-|//|/|\(|\)|\*\*|\*|\\<<|\\<|\\==|\\=|\\>>|\\>|\\|\|\||\||'
+             r'&&|&|%|\+|<<=|<<|<=|<>|<|==|=|><|>=|>>=|>>|>|¬<<|¬<|¬==|¬=|'
+             r'¬>>|¬>|¬|\.|,)', Operator),
+        ],
+        'string_double': [
+            (r'[^"\n]+', String),
+            (r'""', String),
+            (r'"', String, '#pop'),
+            (r'\n', Text, '#pop'),  # Stray linefeed also terminates strings.
+        ],
+        'string_single': [
+            (r'[^\'\n]+', String),
+            (r'\'\'', String),
+            (r'\'', String, '#pop'),
+            (r'\n', Text, '#pop'),  # Stray linefeed also terminates strings.
+        ],
+        'comment': [
+            (r'[^*]+', Comment.Multiline),
+            (r'\*/', Comment.Multiline, '#pop'),
+            (r'\*', Comment.Multiline),
+        ]
+    }
+
+    def _c(s):
+        return re.compile(s, re.MULTILINE)
+    _ADDRESS_COMMAND_PATTERN = _c(r'^\s*address\s+command\b')
+    _ADDRESS_PATTERN = _c(r'^\s*address\s+')
+    _DO_WHILE_PATTERN = _c(r'^\s*do\s+while\b')
+    _IF_THEN_DO_PATTERN = _c(r'^\s*if\b.+\bthen\s+do\s*$')
+    _PROCEDURE_PATTERN = _c(r'^\s*([a-z_]\w*)(\s*)(:)(\s*)(procedure)\b')
+    _ELSE_DO_PATTERN = _c(r'\belse\s+do\s*$')
+    _PARSE_ARG_PATTERN = _c(r'^\s*parse\s+(upper\s+)?(arg|value)\b')
+    PATTERNS_AND_WEIGHTS = (
+        (_ADDRESS_COMMAND_PATTERN, 0.2),
+        (_ADDRESS_PATTERN, 0.05),
+        (_DO_WHILE_PATTERN, 0.1),
+        (_ELSE_DO_PATTERN, 0.1),
+        (_IF_THEN_DO_PATTERN, 0.1),
+        (_PROCEDURE_PATTERN, 0.5),
+        (_PARSE_ARG_PATTERN, 0.2),
+    )
+
+    def analyse_text(text):
+        """
+        Check for initial comment and patterns that distinguish Rexx from other
+        C-like languages.
+        """
+        if re.search(r'/\*\**\s*rexx', text, re.IGNORECASE):
+            # Header matches MVS Rexx requirements, this is certainly a Rexx
+            # script.
+            return 1.0
+        elif text.startswith('/*'):
+            # Header matches general Rexx requirements; the source code might
+            # still be any language using C comments such as C++, C# or Java.
+            lowerText = text.lower()
+            result = sum(weight
+                         for (pattern, weight) in RexxLexer.PATTERNS_AND_WEIGHTS
+                         if pattern.search(lowerText)) + 0.01
+            return min(result, 1.0)
+
+
+class MOOCodeLexer(RegexLexer):
+    """
+    For MOOCode (the MOO scripting language).
+    """
+    name = 'MOOCode'
+    url = 'http://www.moo.mud.org/'
+    filenames = ['*.moo']
+    aliases = ['moocode', 'moo']
+    mimetypes = ['text/x-moocode']
+    version_added = '0.9'
+
+    tokens = {
+        'root': [
+            # Numbers
+            (r'(0|[1-9][0-9_]*)', Number.Integer),
+            # Strings
+            (r'"(\\\\|\\[^\\]|[^"\\])*"', String),
+            # exceptions
+            (r'(E_PERM|E_DIV)', Name.Exception),
+            # db-refs
+            (r'((#[-0-9]+)|(\$\w+))', Name.Entity),
+            # Keywords
+            (r'\b(if|else|elseif|endif|for|endfor|fork|endfork|while'
+             r'|endwhile|break|continue|return|try'
+             r'|except|endtry|finally|in)\b', Keyword),
+            # builtins
+            (r'(random|length)', Name.Builtin),
+            # special variables
+            (r'(player|caller|this|args)', Name.Variable.Instance),
+            # skip whitespace
+            (r'\s+', Text),
+            (r'\n', Text),
+            # other operators
+            (r'([!;=,{}&|:.\[\]@()<>?]+)', Operator),
+            # function call
+            (r'(\w+)(\()', bygroups(Name.Function, Operator)),
+            # variables
+            (r'(\w+)', Text),
+        ]
+    }
+
+
+class HybrisLexer(RegexLexer):
+    """
+    For Hybris source code.
+    """
+
+    name = 'Hybris'
+    aliases = ['hybris']
+    filenames = ['*.hyb']
+    mimetypes = ['text/x-hybris', 'application/x-hybris']
+    url = 'https://github.com/evilsocket/hybris'
+    version_added = '1.4'
+
+    flags = re.MULTILINE | re.DOTALL
+
+    tokens = {
+        'root': [
+            # method names
+            (r'^(\s*(?:function|method|operator\s+)+?)'
+             r'([a-zA-Z_]\w*)'
+             r'(\s*)(\()', bygroups(Keyword, Name.Function, Text, Operator)),
+            (r'[^\S\n]+', Text),
+            (r'//.*?\n', Comment.Single),
+            (r'/\*.*?\*/', Comment.Multiline),
+            (r'@[a-zA-Z_][\w.]*', Name.Decorator),
+            (r'(break|case|catch|next|default|do|else|finally|for|foreach|of|'
+             r'unless|if|new|return|switch|me|throw|try|while)\b', Keyword),
+            (r'(extends|private|protected|public|static|throws|function|method|'
+             r'operator)\b', Keyword.Declaration),
+            (r'(true|false|null|__FILE__|__LINE__|__VERSION__|__LIB_PATH__|'
+             r'__INC_PATH__)\b', Keyword.Constant),
+            (r'(class|struct)(\s+)',
+             bygroups(Keyword.Declaration, Text), 'class'),
+            (r'(import|include)(\s+)',
+             bygroups(Keyword.Namespace, Text), 'import'),
+            (words((
+                'gc_collect', 'gc_mm_items', 'gc_mm_usage', 'gc_collect_threshold',
+                'urlencode', 'urldecode', 'base64encode', 'base64decode', 'sha1', 'crc32',
+                'sha2', 'md5', 'md5_file', 'acos', 'asin', 'atan', 'atan2', 'ceil', 'cos',
+                'cosh', 'exp', 'fabs', 'floor', 'fmod', 'log', 'log10', 'pow', 'sin',
+                'sinh', 'sqrt', 'tan', 'tanh', 'isint', 'isfloat', 'ischar', 'isstring',
+                'isarray', 'ismap', 'isalias', 'typeof', 'sizeof', 'toint', 'tostring',
+                'fromxml', 'toxml', 'binary', 'pack', 'load', 'eval', 'var_names',
+                'var_values', 'user_functions', 'dyn_functions', 'methods', 'call',
+                'call_method', 'mknod', 'mkfifo', 'mount', 'umount2', 'umount', 'ticks',
+                'usleep', 'sleep', 'time', 'strtime', 'strdate', 'dllopen', 'dlllink',
+                'dllcall', 'dllcall_argv', 'dllclose', 'env', 'exec', 'fork', 'getpid',
+                'wait', 'popen', 'pclose', 'exit', 'kill', 'pthread_create',
+                'pthread_create_argv', 'pthread_exit', 'pthread_join', 'pthread_kill',
+                'smtp_send', 'http_get', 'http_post', 'http_download', 'socket', 'bind',
+                'listen', 'accept', 'getsockname', 'getpeername', 'settimeout', 'connect',
+                'server', 'recv', 'send', 'close', 'print', 'println', 'printf', 'input',
+                'readline', 'serial_open', 'serial_fcntl', 'serial_get_attr',
+                'serial_get_ispeed', 'serial_get_ospeed', 'serial_set_attr',
+                'serial_set_ispeed', 'serial_set_ospeed', 'serial_write', 'serial_read',
+                'serial_close', 'xml_load', 'xml_parse', 'fopen', 'fseek', 'ftell',
+                'fsize', 'fread', 'fwrite', 'fgets', 'fclose', 'file', 'readdir',
+                'pcre_replace', 'size', 'pop', 'unmap', 'has', 'keys', 'values',
+                'length', 'find', 'substr', 'replace', 'split', 'trim', 'remove',
+                'contains', 'join'), suffix=r'\b'),
+             Name.Builtin),
+            (words((
+                'MethodReference', 'Runner', 'Dll', 'Thread', 'Pipe', 'Process',
+                'Runnable', 'CGI', 'ClientSocket', 'Socket', 'ServerSocket',
+                'File', 'Console', 'Directory', 'Exception'), suffix=r'\b'),
+             Keyword.Type),
+            (r'"(\\\\|\\[^\\]|[^"\\])*"', String),
+            (r"'\\.'|'[^\\]'|'\\u[0-9a-f]{4}'", String.Char),
+            (r'(\.)([a-zA-Z_]\w*)',
+             bygroups(Operator, Name.Attribute)),
+            (r'[a-zA-Z_]\w*:', Name.Label),
+            (r'[a-zA-Z_$]\w*', Name),
+            (r'[~^*!%&\[\](){}<>|+=:;,./?\-@]+', Operator),
+            (r'[0-9][0-9]*\.[0-9]+([eE][0-9]+)?[fd]?', Number.Float),
+            (r'0x[0-9a-f]+', Number.Hex),
+            (r'[0-9]+L?', Number.Integer),
+            (r'\n', Text),
+        ],
+        'class': [
+            (r'[a-zA-Z_]\w*', Name.Class, '#pop')
+        ],
+        'import': [
+            (r'[\w.]+\*?', Name.Namespace, '#pop')
+        ],
+    }
+
+    def analyse_text(text):
+        """public method and private method don't seem to be quite common
+        elsewhere."""
+        result = 0
+        if re.search(r'\b(?:public|private)\s+method\b', text):
+            result += 0.01
+        return result
+
+
+
+class EasytrieveLexer(RegexLexer):
+    """
+    Easytrieve Plus is a programming language for extracting, filtering and
+    converting sequential data. Furthermore it can layout data for reports.
+    It is mainly used on mainframe platforms and can access several of the
+    mainframe's native file formats. It is somewhat comparable to awk.
+    """
+    name = 'Easytrieve'
+    aliases = ['easytrieve']
+    filenames = ['*.ezt', '*.mac']
+    mimetypes = ['text/x-easytrieve']
+    url = 'https://www.broadcom.com/products/mainframe/application-development/easytrieve-report-generator'
+    version_added = '2.1'
+    flags = 0
+
+    # Note: We cannot use r'\b' at the start and end of keywords because
+    # Easytrieve Plus delimiter characters are:
+    #
+    #   * space ( )
+    #   * apostrophe (')
+    #   * period (.)
+    #   * comma (,)
+    #   * parenthesis ( and )
+    #   * colon (:)
+    #
+    # Additionally words end once a '*' appears, indicatins a comment.
+    _DELIMITERS = r' \'.,():\n'
+    _DELIMITERS_OR_COMENT = _DELIMITERS + '*'
+    _DELIMITER_PATTERN = '[' + _DELIMITERS + ']'
+    _DELIMITER_PATTERN_CAPTURE = '(' + _DELIMITER_PATTERN + ')'
+    _NON_DELIMITER_OR_COMMENT_PATTERN = '[^' + _DELIMITERS_OR_COMENT + ']'
+    _OPERATORS_PATTERN = '[.+\\-/=\\[\\](){}<>;,&%¬]'
+    _KEYWORDS = [
+        'AFTER-BREAK', 'AFTER-LINE', 'AFTER-SCREEN', 'AIM', 'AND', 'ATTR',
+        'BEFORE', 'BEFORE-BREAK', 'BEFORE-LINE', 'BEFORE-SCREEN', 'BUSHU',
+        'BY', 'CALL', 'CASE', 'CHECKPOINT', 'CHKP', 'CHKP-STATUS', 'CLEAR',
+        'CLOSE', 'COL', 'COLOR', 'COMMIT', 'CONTROL', 'COPY', 'CURSOR', 'D',
+        'DECLARE', 'DEFAULT', 'DEFINE', 'DELETE', 'DENWA', 'DISPLAY', 'DLI',
+        'DO', 'DUPLICATE', 'E', 'ELSE', 'ELSE-IF', 'END', 'END-CASE',
+        'END-DO', 'END-IF', 'END-PROC', 'ENDPAGE', 'ENDTABLE', 'ENTER', 'EOF',
+        'EQ', 'ERROR', 'EXIT', 'EXTERNAL', 'EZLIB', 'F1', 'F10', 'F11', 'F12',
+        'F13', 'F14', 'F15', 'F16', 'F17', 'F18', 'F19', 'F2', 'F20', 'F21',
+        'F22', 'F23', 'F24', 'F25', 'F26', 'F27', 'F28', 'F29', 'F3', 'F30',
+        'F31', 'F32', 'F33', 'F34', 'F35', 'F36', 'F4', 'F5', 'F6', 'F7',
+        'F8', 'F9', 'FETCH', 'FILE-STATUS', 'FILL', 'FINAL', 'FIRST',
+        'FIRST-DUP', 'FOR', 'GE', 'GET', 'GO', 'GOTO', 'GQ', 'GR', 'GT',
+        'HEADING', 'HEX', 'HIGH-VALUES', 'IDD', 'IDMS', 'IF', 'IN', 'INSERT',
+        'JUSTIFY', 'KANJI-DATE', 'KANJI-DATE-LONG', 'KANJI-TIME', 'KEY',
+        'KEY-PRESSED', 'KOKUGO', 'KUN', 'LAST-DUP', 'LE', 'LEVEL', 'LIKE',
+        'LINE', 'LINE-COUNT', 'LINE-NUMBER', 'LINK', 'LIST', 'LOW-VALUES',
+        'LQ', 'LS', 'LT', 'MACRO', 'MASK', 'MATCHED', 'MEND', 'MESSAGE',
+        'MOVE', 'MSTART', 'NE', 'NEWPAGE', 'NOMASK', 'NOPRINT', 'NOT',
+        'NOTE', 'NOVERIFY', 'NQ', 'NULL', 'OF', 'OR', 'OTHERWISE', 'PA1',
+        'PA2', 'PA3', 'PAGE-COUNT', 'PAGE-NUMBER', 'PARM-REGISTER',
+        'PATH-ID', 'PATTERN', 'PERFORM', 'POINT', 'POS', 'PRIMARY', 'PRINT',
+        'PROCEDURE', 'PROGRAM', 'PUT', 'READ', 'RECORD', 'RECORD-COUNT',
+        'RECORD-LENGTH', 'REFRESH', 'RELEASE', 'RENUM', 'REPEAT', 'REPORT',
+        'REPORT-INPUT', 'RESHOW', 'RESTART', 'RETRIEVE', 'RETURN-CODE',
+        'ROLLBACK', 'ROW', 'S', 'SCREEN', 'SEARCH', 'SECONDARY', 'SELECT',
+        'SEQUENCE', 'SIZE', 'SKIP', 'SOKAKU', 'SORT', 'SQL', 'STOP', 'SUM',
+        'SYSDATE', 'SYSDATE-LONG', 'SYSIN', 'SYSIPT', 'SYSLST', 'SYSPRINT',
+        'SYSSNAP', 'SYSTIME', 'TALLY', 'TERM-COLUMNS', 'TERM-NAME',
+        'TERM-ROWS', 'TERMINATION', 'TITLE', 'TO', 'TRANSFER', 'TRC',
+        'UNIQUE', 'UNTIL', 'UPDATE', 'UPPERCASE', 'USER', 'USERID', 'VALUE',
+        'VERIFY', 'W', 'WHEN', 'WHILE', 'WORK', 'WRITE', 'X', 'XDM', 'XRST'
+    ]
+
+    tokens = {
+        'root': [
+            (r'\*.*\n', Comment.Single),
+            (r'\n+', Whitespace),
+            # Macro argument
+            (r'&' + _NON_DELIMITER_OR_COMMENT_PATTERN + r'+\.', Name.Variable,
+             'after_macro_argument'),
+            # Macro call
+            (r'%' + _NON_DELIMITER_OR_COMMENT_PATTERN + r'+', Name.Variable),
+            (r'(FILE|MACRO|REPORT)(\s+)',
+             bygroups(Keyword.Declaration, Whitespace), 'after_declaration'),
+            (r'(JOB|PARM)' + r'(' + _DELIMITER_PATTERN + r')',
+             bygroups(Keyword.Declaration, Operator)),
+            (words(_KEYWORDS, suffix=_DELIMITER_PATTERN_CAPTURE),
+             bygroups(Keyword.Reserved, Operator)),
+            (_OPERATORS_PATTERN, Operator),
+            # Procedure declaration
+            (r'(' + _NON_DELIMITER_OR_COMMENT_PATTERN + r'+)(\s*)(\.?)(\s*)(PROC)(\s*\n)',
+             bygroups(Name.Function, Whitespace, Operator, Whitespace,
+                      Keyword.Declaration, Whitespace)),
+            (r'[0-9]+\.[0-9]*', Number.Float),
+            (r'[0-9]+', Number.Integer),
+            (r"'(''|[^'])*'", String),
+            (r'\s+', Whitespace),
+            # Everything else just belongs to a name
+            (_NON_DELIMITER_OR_COMMENT_PATTERN + r'+', Name),
+         ],
+        'after_declaration': [
+            (_NON_DELIMITER_OR_COMMENT_PATTERN + r'+', Name.Function),
+            default('#pop'),
+        ],
+        'after_macro_argument': [
+            (r'\*.*\n', Comment.Single, '#pop'),
+            (r'\s+', Whitespace, '#pop'),
+            (_OPERATORS_PATTERN, Operator, '#pop'),
+            (r"'(''|[^'])*'", String, '#pop'),
+            # Everything else just belongs to a name
+            (_NON_DELIMITER_OR_COMMENT_PATTERN + r'+', Name),
+        ],
+    }
+    _COMMENT_LINE_REGEX = re.compile(r'^\s*\*')
+    _MACRO_HEADER_REGEX = re.compile(r'^\s*MACRO')
+
+    def analyse_text(text):
+        """
+        Perform a structural analysis for basic Easytrieve constructs.
+        """
+        result = 0.0
+        lines = text.split('\n')
+        hasEndProc = False
+        hasHeaderComment = False
+        hasFile = False
+        hasJob = False
+        hasProc = False
+        hasParm = False
+        hasReport = False
+
+        def isCommentLine(line):
+            return EasytrieveLexer._COMMENT_LINE_REGEX.match(lines[0]) is not None
+
+        def isEmptyLine(line):
+            return not bool(line.strip())
+
+        # Remove possible empty lines and header comments.
+        while lines and (isEmptyLine(lines[0]) or isCommentLine(lines[0])):
+            if not isEmptyLine(lines[0]):
+                hasHeaderComment = True
+            del lines[0]
+
+        if EasytrieveLexer._MACRO_HEADER_REGEX.match(lines[0]):
+            # Looks like an Easytrieve macro.
+            result = 0.4
+            if hasHeaderComment:
+                result += 0.4
+        else:
+            # Scan the source for lines starting with indicators.
+            for line in lines:
+                words = line.split()
+                if (len(words) >= 2):
+                    firstWord = words[0]
+                    if not hasReport:
+                        if not hasJob:
+                            if not hasFile:
+                                if not hasParm:
+                                    if firstWord == 'PARM':
+                                        hasParm = True
+                                if firstWord == 'FILE':
+                                    hasFile = True
+                            if firstWord == 'JOB':
+                                hasJob = True
+                        elif firstWord == 'PROC':
+                            hasProc = True
+                        elif firstWord == 'END-PROC':
+                            hasEndProc = True
+                        elif firstWord == 'REPORT':
+                            hasReport = True
+
+            # Weight the findings.
+            if hasJob and (hasProc == hasEndProc):
+                if hasHeaderComment:
+                    result += 0.1
+                if hasParm:
+                    if hasProc:
+                        # Found PARM, JOB and PROC/END-PROC:
+                        # pretty sure this is Easytrieve.
+                        result += 0.8
+                    else:
+                        # Found PARAM and  JOB: probably this is Easytrieve
+                        result += 0.5
+                else:
+                    # Found JOB and possibly other keywords: might be Easytrieve
+                    result += 0.11
+                    if hasParm:
+                        # Note: PARAM is not a proper English word, so this is
+                        # regarded a much better indicator for Easytrieve than
+                        # the other words.
+                        result += 0.2
+                    if hasFile:
+                        result += 0.01
+                    if hasReport:
+                        result += 0.01
+        assert 0.0 <= result <= 1.0
+        return result
+
+
+class JclLexer(RegexLexer):
+    """
+    Job Control Language (JCL)
+    is a scripting language used on mainframe platforms to instruct the system
+    on how to run a batch job or start a subsystem. It is somewhat
+    comparable to MS DOS batch and Unix shell scripts.
+    """
+    name = 'JCL'
+    aliases = ['jcl']
+    filenames = ['*.jcl']
+    mimetypes = ['text/x-jcl']
+    url = 'https://en.wikipedia.org/wiki/Job_Control_Language'
+    version_added = '2.1'
+
+    flags = re.IGNORECASE
+
+    tokens = {
+        'root': [
+            (r'//\*.*\n', Comment.Single),
+            (r'//', Keyword.Pseudo, 'statement'),
+            (r'/\*', Keyword.Pseudo, 'jes2_statement'),
+            # TODO: JES3 statement
+            (r'.*\n', Other)  # Input text or inline code in any language.
+        ],
+        'statement': [
+            (r'\s*\n', Whitespace, '#pop'),
+            (r'([a-z]\w*)(\s+)(exec|job)(\s*)',
+             bygroups(Name.Label, Whitespace, Keyword.Reserved, Whitespace),
+             'option'),
+            (r'[a-z]\w*', Name.Variable, 'statement_command'),
+            (r'\s+', Whitespace, 'statement_command'),
+        ],
+        'statement_command': [
+            (r'\s+(command|cntl|dd|endctl|endif|else|include|jcllib|'
+             r'output|pend|proc|set|then|xmit)\s+', Keyword.Reserved, 'option'),
+            include('option')
+        ],
+        'jes2_statement': [
+            (r'\s*\n', Whitespace, '#pop'),
+            (r'\$', Keyword, 'option'),
+            (r'\b(jobparam|message|netacct|notify|output|priority|route|'
+             r'setup|signoff|xeq|xmit)\b', Keyword, 'option'),
+        ],
+        'option': [
+            # (r'\n', Text, 'root'),
+            (r'\*', Name.Builtin),
+            (r'[\[\](){}<>;,]', Punctuation),
+            (r'[-+*/=&%]', Operator),
+            (r'[a-z_]\w*', Name),
+            (r'\d+\.\d*', Number.Float),
+            (r'\.\d+', Number.Float),
+            (r'\d+', Number.Integer),
+            (r"'", String, 'option_string'),
+            (r'[ \t]+', Whitespace, 'option_comment'),
+            (r'\.', Punctuation),
+        ],
+        'option_string': [
+            (r"(\n)(//)", bygroups(Text, Keyword.Pseudo)),
+            (r"''", String),
+            (r"[^']", String),
+            (r"'", String, '#pop'),
+        ],
+        'option_comment': [
+            # (r'\n', Text, 'root'),
+            (r'.+', Comment.Single),
+        ]
+    }
+
+    _JOB_HEADER_PATTERN = re.compile(r'^//[a-z#$@][a-z0-9#$@]{0,7}\s+job(\s+.*)?$',
+                                     re.IGNORECASE)
+
+    def analyse_text(text):
+        """
+        Recognize JCL job by header.
+        """
+        result = 0.0
+        lines = text.split('\n')
+        if len(lines) > 0:
+            if JclLexer._JOB_HEADER_PATTERN.match(lines[0]):
+                result = 1.0
+        assert 0.0 <= result <= 1.0
+        return result
+
+
+class MiniScriptLexer(RegexLexer):
+    """
+    For MiniScript source code.
+    """
+
+    name = 'MiniScript'
+    url = 'https://miniscript.org'
+    aliases = ['miniscript', 'ms']
+    filenames = ['*.ms']
+    mimetypes = ['text/x-minicript', 'application/x-miniscript']
+    version_added = '2.6'
+
+    tokens = {
+        'root': [
+            (r'#!(.*?)$', Comment.Preproc),
+            default('base'),
+        ],
+        'base': [
+            ('//.*$', Comment.Single),
+            (r'(?i)(\d*\.\d+|\d+\.\d*)(e[+-]?\d+)?', Number),
+            (r'(?i)\d+e[+-]?\d+', Number),
+            (r'\d+', Number),
+            (r'\n', Text),
+            (r'[^\S\n]+', Text),
+            (r'"', String, 'string_double'),
+            (r'(==|!=|<=|>=|[=+\-*/%^<>.:])', Operator),
+            (r'[;,\[\]{}()]', Punctuation),
+            (words((
+                'break', 'continue', 'else', 'end', 'for', 'function', 'if',
+                'in', 'isa', 'then', 'repeat', 'return', 'while'), suffix=r'\b'),
+             Keyword),
+            (words((
+                'abs', 'acos', 'asin', 'atan', 'ceil', 'char', 'cos', 'floor',
+                'log', 'round', 'rnd', 'pi', 'sign', 'sin', 'sqrt', 'str', 'tan',
+                'hasIndex', 'indexOf', 'len', 'val', 'code', 'remove', 'lower',
+                'upper', 'replace', 'split', 'indexes', 'values', 'join', 'sum',
+                'sort', 'shuffle', 'push', 'pop', 'pull', 'range',
+                'print', 'input', 'time', 'wait', 'locals', 'globals', 'outer',
+                'yield'), suffix=r'\b'),
+             Name.Builtin),
+            (r'(true|false|null)\b', Keyword.Constant),
+            (r'(and|or|not|new)\b', Operator.Word),
+            (r'(self|super|__isa)\b', Name.Builtin.Pseudo),
+            (r'[a-zA-Z_]\w*', Name.Variable)
+        ],
+        'string_double': [
+            (r'[^"\n]+', String),
+            (r'""', String),
+            (r'"', String, '#pop'),
+            (r'\n', Text, '#pop'),  # Stray linefeed also terminates strings.
+        ]
+    }
diff --git a/pygments/lexers/sgf.py b/pygments/lexers/sgf.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0e56cba55242ba7874e1e71186be35986537407
--- /dev/null
+++ b/pygments/lexers/sgf.py
@@ -0,0 +1,59 @@
+"""
+    pygments.lexers.sgf
+    ~~~~~~~~~~~~~~~~~~~
+
+    Lexer for Smart Game Format (sgf) file format.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, bygroups
+from pygments.token import Name, Literal, String, Punctuation, Whitespace
+
+__all__ = ["SmartGameFormatLexer"]
+
+
+class SmartGameFormatLexer(RegexLexer):
+    """
+    Lexer for Smart Game Format (sgf) file format.
+
+    The format is used to store game records of board games for two players
+    (mainly Go game).
+    """
+    name = 'SmartGameFormat'
+    url = 'https://www.red-bean.com/sgf/'
+    aliases = ['sgf']
+    filenames = ['*.sgf']
+    version_added = '2.4'
+
+    tokens = {
+        'root': [
+            (r'[():;]+', Punctuation),
+            # tokens:
+            (r'(A[BW]|AE|AN|AP|AR|AS|[BW]L|BM|[BW]R|[BW]S|[BW]T|CA|CH|CP|CR|'
+             r'DD|DM|DO|DT|EL|EV|EX|FF|FG|G[BW]|GC|GM|GN|HA|HO|ID|IP|IT|IY|KM|'
+             r'KO|LB|LN|LT|L|MA|MN|M|N|OB|OM|ON|OP|OT|OV|P[BW]|PC|PL|PM|RE|RG|'
+             r'RO|RU|SO|SC|SE|SI|SL|SO|SQ|ST|SU|SZ|T[BW]|TC|TE|TM|TR|UC|US|VW|'
+             r'V|[BW]|C)',
+             Name.Builtin),
+            # number:
+            (r'(\[)([0-9.]+)(\])',
+             bygroups(Punctuation, Literal.Number, Punctuation)),
+            # date:
+            (r'(\[)([0-9]{4}-[0-9]{2}-[0-9]{2})(\])',
+             bygroups(Punctuation, Literal.Date, Punctuation)),
+            # point:
+            (r'(\[)([a-z]{2})(\])',
+             bygroups(Punctuation, String, Punctuation)),
+            # double points:
+            (r'(\[)([a-z]{2})(:)([a-z]{2})(\])',
+             bygroups(Punctuation, String, Punctuation, String, Punctuation)),
+
+            (r'(\[)([\w\s#()+,\-.:?]+)(\])',
+             bygroups(Punctuation, String, Punctuation)),
+            (r'(\[)(\s.*)(\])',
+             bygroups(Punctuation, Whitespace, Punctuation)),
+            (r'\s+', Whitespace)
+        ],
+    }
diff --git a/pygments/lexers/shell.py b/pygments/lexers/shell.py
new file mode 100644
index 0000000000000000000000000000000000000000..744767a1d4464f838b7c062f63c850e3a2f18554
--- /dev/null
+++ b/pygments/lexers/shell.py
@@ -0,0 +1,902 @@
+"""
+    pygments.lexers.shell
+    ~~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for various shells.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import re
+
+from pygments.lexer import Lexer, RegexLexer, do_insertions, bygroups, \
+    include, default, this, using, words, line_re
+from pygments.token import Punctuation, Whitespace, \
+    Text, Comment, Operator, Keyword, Name, String, Number, Generic
+from pygments.util import shebang_matches
+
+__all__ = ['BashLexer', 'BashSessionLexer', 'TcshLexer', 'BatchLexer',
+           'SlurmBashLexer', 'MSDOSSessionLexer', 'PowerShellLexer',
+           'PowerShellSessionLexer', 'TcshSessionLexer', 'FishShellLexer',
+           'ExeclineLexer']
+
+
+class BashLexer(RegexLexer):
+    """
+    Lexer for (ba|k|z|)sh shell scripts.
+    """
+
+    name = 'Bash'
+    aliases = ['bash', 'sh', 'ksh', 'zsh', 'shell', 'openrc']
+    filenames = ['*.sh', '*.ksh', '*.bash', '*.ebuild', '*.eclass',
+                 '*.exheres-0', '*.exlib', '*.zsh',
+                 '.bashrc', 'bashrc', '.bash_*', 'bash_*', 'zshrc', '.zshrc',
+                 '.kshrc', 'kshrc',
+                 'PKGBUILD']
+    mimetypes = ['application/x-sh', 'application/x-shellscript', 'text/x-shellscript']
+    url = 'https://en.wikipedia.org/wiki/Unix_shell'
+    version_added = '0.6'
+
+    tokens = {
+        'root': [
+            include('basic'),
+            (r'`', String.Backtick, 'backticks'),
+            include('data'),
+            include('interp'),
+        ],
+        'interp': [
+            (r'\$\(\(', Keyword, 'math'),
+            (r'\$\(', Keyword, 'paren'),
+            (r'\$\{#?', String.Interpol, 'curly'),
+            (r'\$[a-zA-Z_]\w*', Name.Variable),  # user variable
+            (r'\$(?:\d+|[#$?!_*@-])', Name.Variable),      # builtin
+            (r'\$', Text),
+        ],
+        'basic': [
+            (r'\b(if|fi|else|while|in|do|done|for|then|return|function|case|'
+             r'select|break|continue|until|esac|elif)(\s*)\b',
+             bygroups(Keyword, Whitespace)),
+            (r'\b(alias|bg|bind|builtin|caller|cd|command|compgen|'
+             r'complete|declare|dirs|disown|echo|enable|eval|exec|exit|'
+             r'export|false|fc|fg|getopts|hash|help|history|jobs|kill|let|'
+             r'local|logout|popd|printf|pushd|pwd|read|readonly|set|shift|'
+             r'shopt|source|suspend|test|time|times|trap|true|type|typeset|'
+             r'ulimit|umask|unalias|unset|wait)(?=[\s)`])',
+             Name.Builtin),
+            (r'\A#!.+\n', Comment.Hashbang),
+            (r'#.*\n', Comment.Single),
+            (r'\\[\w\W]', String.Escape),
+            (r'(\b\w+)(\s*)(\+?=)', bygroups(Name.Variable, Whitespace, Operator)),
+            (r'[\[\]{}()=]', Operator),
+            (r'<<<', Operator),  # here-string
+            (r'<<-?\s*(\'?)\\?(\w+)[\w\W]+?\2', String),
+            (r'&&|\|\|', Operator),
+        ],
+        'data': [
+            (r'(?s)\$?"(\\.|[^"\\$])*"', String.Double),
+            (r'"', String.Double, 'string'),
+            (r"(?s)\$'(\\\\|\\[0-7]+|\\.|[^'\\])*'", String.Single),
+            (r"(?s)'.*?'", String.Single),
+            (r';', Punctuation),
+            (r'&', Punctuation),
+            (r'\|', Punctuation),
+            (r'\s+', Whitespace),
+            (r'\d+\b', Number),
+            (r'[^=\s\[\]{}()$"\'`\\<&|;]+', Text),
+            (r'<', Text),
+        ],
+        'string': [
+            (r'"', String.Double, '#pop'),
+            (r'(?s)(\\\\|\\[0-7]+|\\.|[^"\\$])+', String.Double),
+            include('interp'),
+        ],
+        'curly': [
+            (r'\}', String.Interpol, '#pop'),
+            (r':-', Keyword),
+            (r'\w+', Name.Variable),
+            (r'[^}:"\'`$\\]+', Punctuation),
+            (r':', Punctuation),
+            include('root'),
+        ],
+        'paren': [
+            (r'\)', Keyword, '#pop'),
+            include('root'),
+        ],
+        'math': [
+            (r'\)\)', Keyword, '#pop'),
+            (r'\*\*|\|\||<<|>>|[-+*/%^|&<>]', Operator),
+            (r'\d+#[\da-zA-Z]+', Number),
+            (r'\d+#(?! )', Number),
+            (r'0[xX][\da-fA-F]+', Number),
+            (r'\d+', Number),
+            (r'[a-zA-Z_]\w*', Name.Variable),  # user variable
+            include('root'),
+        ],
+        'backticks': [
+            (r'`', String.Backtick, '#pop'),
+            include('root'),
+        ],
+    }
+
+    def analyse_text(text):
+        if shebang_matches(text, r'(ba|z|)sh'):
+            return 1
+        if text.startswith('$ '):
+            return 0.2
+
+
+class SlurmBashLexer(BashLexer):
+    """
+    Lexer for (ba|k|z|)sh Slurm scripts.
+    """
+
+    name = 'Slurm'
+    aliases = ['slurm', 'sbatch']
+    filenames = ['*.sl']
+    mimetypes = []
+    version_added = '2.4'
+    EXTRA_KEYWORDS = {'srun'}
+
+    def get_tokens_unprocessed(self, text):
+        for index, token, value in BashLexer.get_tokens_unprocessed(self, text):
+            if token is Text and value in self.EXTRA_KEYWORDS:
+                yield index, Name.Builtin, value
+            elif token is Comment.Single and 'SBATCH' in value:
+                yield index, Keyword.Pseudo, value
+            else:
+                yield index, token, value
+
+
+class ShellSessionBaseLexer(Lexer):
+    """
+    Base lexer for shell sessions.
+
+    .. versionadded:: 2.1
+    """
+
+    _bare_continuation = False
+    _venv = re.compile(r'^(\([^)]*\))(\s*)')
+
+    def get_tokens_unprocessed(self, text):
+        innerlexer = self._innerLexerCls(**self.options)
+
+        pos = 0
+        curcode = ''
+        insertions = []
+        backslash_continuation = False
+
+        for match in line_re.finditer(text):
+            line = match.group()
+
+            venv_match = self._venv.match(line)
+            if venv_match:
+                venv = venv_match.group(1)
+                venv_whitespace = venv_match.group(2)
+                insertions.append((len(curcode),
+                                   [(0, Generic.Prompt.VirtualEnv, venv)]))
+                if venv_whitespace:
+                    insertions.append((len(curcode),
+                                       [(0, Text, venv_whitespace)]))
+                line = line[venv_match.end():]
+
+            m = self._ps1rgx.match(line)
+            if m:
+                # To support output lexers (say diff output), the output
+                # needs to be broken by prompts whenever the output lexer
+                # changes.
+                if not insertions:
+                    pos = match.start()
+
+                insertions.append((len(curcode),
+                                   [(0, Generic.Prompt, m.group(1))]))
+                curcode += m.group(2)
+                backslash_continuation = curcode.endswith('\\\n')
+            elif backslash_continuation:
+                if line.startswith(self._ps2):
+                    insertions.append((len(curcode),
+                                       [(0, Generic.Prompt,
+                                         line[:len(self._ps2)])]))
+                    curcode += line[len(self._ps2):]
+                else:
+                    curcode += line
+                backslash_continuation = curcode.endswith('\\\n')
+            elif self._bare_continuation and line.startswith(self._ps2):
+                insertions.append((len(curcode),
+                                   [(0, Generic.Prompt,
+                                     line[:len(self._ps2)])]))
+                curcode += line[len(self._ps2):]
+            else:
+                if insertions:
+                    toks = innerlexer.get_tokens_unprocessed(curcode)
+                    for i, t, v in do_insertions(insertions, toks):
+                        yield pos+i, t, v
+                yield match.start(), Generic.Output, line
+                insertions = []
+                curcode = ''
+        if insertions:
+            for i, t, v in do_insertions(insertions,
+                                         innerlexer.get_tokens_unprocessed(curcode)):
+                yield pos+i, t, v
+
+
+class BashSessionLexer(ShellSessionBaseLexer):
+    """
+    Lexer for Bash shell sessions, i.e. command lines, including a
+    prompt, interspersed with output.
+    """
+
+    name = 'Bash Session'
+    aliases = ['console', 'shell-session']
+    filenames = ['*.sh-session', '*.shell-session']
+    mimetypes = ['application/x-shell-session', 'application/x-sh-session']
+    url = 'https://en.wikipedia.org/wiki/Unix_shell'
+    version_added = '1.1'
+    _example = "console/example.sh-session"
+
+    _innerLexerCls = BashLexer
+    _ps1rgx = re.compile(
+        r'^((?:(?:\[.*?\])|(?:\(\S+\))?(?:| |sh\S*?|\w+\S+[@:]\S+(?:\s+\S+)' \
+        r'?|\[\S+[@:][^\n]+\].+))\s*[$#%]\s*)(.*\n?)')
+    _ps2 = '> '
+
+
+class BatchLexer(RegexLexer):
+    """
+    Lexer for the DOS/Windows Batch file format.
+    """
+    name = 'Batchfile'
+    aliases = ['batch', 'bat', 'dosbatch', 'winbatch']
+    filenames = ['*.bat', '*.cmd']
+    mimetypes = ['application/x-dos-batch']
+    url = 'https://en.wikipedia.org/wiki/Batch_file'
+    version_added = '0.7'
+
+    flags = re.MULTILINE | re.IGNORECASE
+
+    _nl = r'\n\x1a'
+    _punct = r'&<>|'
+    _ws = r'\t\v\f\r ,;=\xa0'
+    _nlws = r'\s\x1a\xa0,;='
+    _space = rf'(?:(?:(?:\^[{_nl}])?[{_ws}])+)'
+    _keyword_terminator = (rf'(?=(?:\^[{_nl}]?)?[{_ws}+./:[\\\]]|[{_nl}{_punct}(])')
+    _token_terminator = rf'(?=\^?[{_ws}]|[{_punct}{_nl}])'
+    _start_label = rf'((?:(?<=^[^:])|^[^:]?)[{_ws}]*)(:)'
+    _label = rf'(?:(?:[^{_nlws}{_punct}+:^]|\^[{_nl}]?[\w\W])*)'
+    _label_compound = rf'(?:(?:[^{_nlws}{_punct}+:^)]|\^[{_nl}]?[^)])*)'
+    _number = rf'(?:-?(?:0[0-7]+|0x[\da-f]+|\d+){_token_terminator})'
+    _opword = r'(?:equ|geq|gtr|leq|lss|neq)'
+    _string = rf'(?:"[^{_nl}"]*(?:"|(?=[{_nl}])))'
+    _variable = (r'(?:(?:%(?:\*|(?:~[a-z]*(?:\$[^:]+:)?)?\d|'
+                 rf'[^%:{_nl}]+(?::(?:~(?:-?\d+)?(?:,(?:-?\d+)?)?|(?:[^%{_nl}^]|'
+                 rf'\^[^%{_nl}])[^={_nl}]*=(?:[^%{_nl}^]|\^[^%{_nl}])*)?)?%))|'
+                 rf'(?:\^?![^!:{_nl}]+(?::(?:~(?:-?\d+)?(?:,(?:-?\d+)?)?|(?:'
+                 rf'[^!{_nl}^]|\^[^!{_nl}])[^={_nl}]*=(?:[^!{_nl}^]|\^[^!{_nl}])*)?)?\^?!))')
+    _core_token = rf'(?:(?:(?:\^[{_nl}]?)?[^"{_nlws}{_punct}])+)'
+    _core_token_compound = rf'(?:(?:(?:\^[{_nl}]?)?[^"{_nlws}{_punct})])+)'
+    _token = rf'(?:[{_punct}]+|{_core_token})'
+    _token_compound = rf'(?:[{_punct}]+|{_core_token_compound})'
+    _stoken = (rf'(?:[{_punct}]+|(?:{_string}|{_variable}|{_core_token})+)')
+
+    def _make_begin_state(compound, _core_token=_core_token,
+                          _core_token_compound=_core_token_compound,
+                          _keyword_terminator=_keyword_terminator,
+                          _nl=_nl, _punct=_punct, _string=_string,
+                          _space=_space, _start_label=_start_label,
+                          _stoken=_stoken, _token_terminator=_token_terminator,
+                          _variable=_variable, _ws=_ws):
+        rest = '(?:{}|{}|[^"%{}{}{}])*'.format(_string, _variable, _nl, _punct,
+                                            ')' if compound else '')
+        rest_of_line = rf'(?:(?:[^{_nl}^]|\^[{_nl}]?[\w\W])*)'
+        rest_of_line_compound = rf'(?:(?:[^{_nl}^)]|\^[{_nl}]?[^)])*)'
+        set_space = rf'((?:(?:\^[{_nl}]?)?[^\S\n])*)'
+        suffix = ''
+        if compound:
+            _keyword_terminator = rf'(?:(?=\))|{_keyword_terminator})'
+            _token_terminator = rf'(?:(?=\))|{_token_terminator})'
+            suffix = '/compound'
+        return [
+            ((r'\)', Punctuation, '#pop') if compound else
+             (rf'\)((?=\()|{_token_terminator}){rest_of_line}',
+              Comment.Single)),
+            (rf'(?={_start_label})', Text, f'follow{suffix}'),
+            (_space, using(this, state='text')),
+            include(f'redirect{suffix}'),
+            (rf'[{_nl}]+', Text),
+            (r'\(', Punctuation, 'root/compound'),
+            (r'@+', Punctuation),
+            (rf'((?:for|if|rem)(?:(?=(?:\^[{_nl}]?)?/)|(?:(?!\^)|'
+             rf'(?<=m))(?:(?=\()|{_token_terminator})))({_space}?{_core_token_compound if compound else _core_token}?(?:\^[{_nl}]?)?/(?:\^[{_nl}]?)?\?)',
+             bygroups(Keyword, using(this, state='text')),
+             f'follow{suffix}'),
+            (rf'(goto{_keyword_terminator})({rest}(?:\^[{_nl}]?)?/(?:\^[{_nl}]?)?\?{rest})',
+             bygroups(Keyword, using(this, state='text')),
+             f'follow{suffix}'),
+            (words(('assoc', 'break', 'cd', 'chdir', 'cls', 'color', 'copy',
+                    'date', 'del', 'dir', 'dpath', 'echo', 'endlocal', 'erase',
+                    'exit', 'ftype', 'keys', 'md', 'mkdir', 'mklink', 'move',
+                    'path', 'pause', 'popd', 'prompt', 'pushd', 'rd', 'ren',
+                    'rename', 'rmdir', 'setlocal', 'shift', 'start', 'time',
+                    'title', 'type', 'ver', 'verify', 'vol'),
+                   suffix=_keyword_terminator), Keyword, f'follow{suffix}'),
+            (rf'(call)({_space}?)(:)',
+             bygroups(Keyword, using(this, state='text'), Punctuation),
+             f'call{suffix}'),
+            (rf'call{_keyword_terminator}', Keyword),
+            (rf'(for{_token_terminator}(?!\^))({_space})(/f{_token_terminator})',
+             bygroups(Keyword, using(this, state='text'), Keyword),
+             ('for/f', 'for')),
+            (rf'(for{_token_terminator}(?!\^))({_space})(/l{_token_terminator})',
+             bygroups(Keyword, using(this, state='text'), Keyword),
+             ('for/l', 'for')),
+            (rf'for{_token_terminator}(?!\^)', Keyword, ('for2', 'for')),
+            (rf'(goto{_keyword_terminator})({_space}?)(:?)',
+             bygroups(Keyword, using(this, state='text'), Punctuation),
+             f'label{suffix}'),
+            (rf'(if(?:(?=\()|{_token_terminator})(?!\^))({_space}?)((?:/i{_token_terminator})?)({_space}?)((?:not{_token_terminator})?)({_space}?)',
+             bygroups(Keyword, using(this, state='text'), Keyword,
+                      using(this, state='text'), Keyword,
+                      using(this, state='text')), ('(?', 'if')),
+            (rf'rem(((?=\()|{_token_terminator}){_space}?{_stoken}?.*|{_keyword_terminator}{rest_of_line_compound if compound else rest_of_line})',
+             Comment.Single, f'follow{suffix}'),
+            (rf'(set{_keyword_terminator}){set_space}(/a)',
+             bygroups(Keyword, using(this, state='text'), Keyword),
+             f'arithmetic{suffix}'),
+            (r'(set{}){}((?:/p)?){}((?:(?:(?:\^[{}]?)?[^"{}{}^={}]|'
+             r'\^[{}]?[^"=])+)?)((?:(?:\^[{}]?)?=)?)'.format(_keyword_terminator, set_space, set_space, _nl, _nl, _punct,
+              ')' if compound else '', _nl, _nl),
+             bygroups(Keyword, using(this, state='text'), Keyword,
+                      using(this, state='text'), using(this, state='variable'),
+                      Punctuation),
+             f'follow{suffix}'),
+            default(f'follow{suffix}')
+        ]
+
+    def _make_follow_state(compound, _label=_label,
+                           _label_compound=_label_compound, _nl=_nl,
+                           _space=_space, _start_label=_start_label,
+                           _token=_token, _token_compound=_token_compound,
+                           _ws=_ws):
+        suffix = '/compound' if compound else ''
+        state = []
+        if compound:
+            state.append((r'(?=\))', Text, '#pop'))
+        state += [
+            (rf'{_start_label}([{_ws}]*)({_label_compound if compound else _label})(.*)',
+             bygroups(Text, Punctuation, Text, Name.Label, Comment.Single)),
+            include(f'redirect{suffix}'),
+            (rf'(?=[{_nl}])', Text, '#pop'),
+            (r'\|\|?|&&?', Punctuation, '#pop'),
+            include('text')
+        ]
+        return state
+
+    def _make_arithmetic_state(compound, _nl=_nl, _punct=_punct,
+                               _string=_string, _variable=_variable,
+                               _ws=_ws, _nlws=_nlws):
+        op = r'=+\-*/!~'
+        state = []
+        if compound:
+            state.append((r'(?=\))', Text, '#pop'))
+        state += [
+            (r'0[0-7]+', Number.Oct),
+            (r'0x[\da-f]+', Number.Hex),
+            (r'\d+', Number.Integer),
+            (r'[(),]+', Punctuation),
+            (rf'([{op}]|%|\^\^)+', Operator),
+            (r'({}|{}|(\^[{}]?)?[^(){}%\^"{}{}]|\^[{}]?{})+'.format(_string, _variable, _nl, op, _nlws, _punct, _nlws,
+              r'[^)]' if compound else r'[\w\W]'),
+             using(this, state='variable')),
+            (r'(?=[\x00|&])', Text, '#pop'),
+            include('follow')
+        ]
+        return state
+
+    def _make_call_state(compound, _label=_label,
+                         _label_compound=_label_compound):
+        state = []
+        if compound:
+            state.append((r'(?=\))', Text, '#pop'))
+        state.append((r'(:?)(%s)' % (_label_compound if compound else _label),
+                      bygroups(Punctuation, Name.Label), '#pop'))
+        return state
+
+    def _make_label_state(compound, _label=_label,
+                          _label_compound=_label_compound, _nl=_nl,
+                          _punct=_punct, _string=_string, _variable=_variable):
+        state = []
+        if compound:
+            state.append((r'(?=\))', Text, '#pop'))
+        state.append((r'({}?)((?:{}|{}|\^[{}]?{}|[^"%^{}{}{}])*)'.format(_label_compound if compound else _label, _string,
+                       _variable, _nl, r'[^)]' if compound else r'[\w\W]', _nl,
+                       _punct, r')' if compound else ''),
+                      bygroups(Name.Label, Comment.Single), '#pop'))
+        return state
+
+    def _make_redirect_state(compound,
+                             _core_token_compound=_core_token_compound,
+                             _nl=_nl, _punct=_punct, _stoken=_stoken,
+                             _string=_string, _space=_space,
+                             _variable=_variable, _nlws=_nlws):
+        stoken_compound = (rf'(?:[{_punct}]+|(?:{_string}|{_variable}|{_core_token_compound})+)')
+        return [
+            (rf'((?:(?<=[{_nlws}])\d)?)(>>?&|<&)([{_nlws}]*)(\d)',
+             bygroups(Number.Integer, Punctuation, Text, Number.Integer)),
+            (rf'((?:(?<=[{_nlws}])(?>?|<)({_space}?{stoken_compound if compound else _stoken})',
+             bygroups(Number.Integer, Punctuation, using(this, state='text')))
+        ]
+
+    tokens = {
+        'root': _make_begin_state(False),
+        'follow': _make_follow_state(False),
+        'arithmetic': _make_arithmetic_state(False),
+        'call': _make_call_state(False),
+        'label': _make_label_state(False),
+        'redirect': _make_redirect_state(False),
+        'root/compound': _make_begin_state(True),
+        'follow/compound': _make_follow_state(True),
+        'arithmetic/compound': _make_arithmetic_state(True),
+        'call/compound': _make_call_state(True),
+        'label/compound': _make_label_state(True),
+        'redirect/compound': _make_redirect_state(True),
+        'variable-or-escape': [
+            (_variable, Name.Variable),
+            (rf'%%|\^[{_nl}]?(\^!|[\w\W])', String.Escape)
+        ],
+        'string': [
+            (r'"', String.Double, '#pop'),
+            (_variable, Name.Variable),
+            (r'\^!|%%', String.Escape),
+            (rf'[^"%^{_nl}]+|[%^]', String.Double),
+            default('#pop')
+        ],
+        'sqstring': [
+            include('variable-or-escape'),
+            (r'[^%]+|%', String.Single)
+        ],
+        'bqstring': [
+            include('variable-or-escape'),
+            (r'[^%]+|%', String.Backtick)
+        ],
+        'text': [
+            (r'"', String.Double, 'string'),
+            include('variable-or-escape'),
+            (rf'[^"%^{_nlws}{_punct}\d)]+|.', Text)
+        ],
+        'variable': [
+            (r'"', String.Double, 'string'),
+            include('variable-or-escape'),
+            (rf'[^"%^{_nl}]+|.', Name.Variable)
+        ],
+        'for': [
+            (rf'({_space})(in)({_space})(\()',
+             bygroups(using(this, state='text'), Keyword,
+                      using(this, state='text'), Punctuation), '#pop'),
+            include('follow')
+        ],
+        'for2': [
+            (r'\)', Punctuation),
+            (rf'({_space})(do{_token_terminator})',
+             bygroups(using(this, state='text'), Keyword), '#pop'),
+            (rf'[{_nl}]+', Text),
+            include('follow')
+        ],
+        'for/f': [
+            (rf'(")((?:{_variable}|[^"])*?")([{_nlws}]*)(\))',
+             bygroups(String.Double, using(this, state='string'), Text,
+                      Punctuation)),
+            (r'"', String.Double, ('#pop', 'for2', 'string')),
+            (rf"('(?:%%|{_variable}|[\w\W])*?')([{_nlws}]*)(\))",
+             bygroups(using(this, state='sqstring'), Text, Punctuation)),
+            (rf'(`(?:%%|{_variable}|[\w\W])*?`)([{_nlws}]*)(\))',
+             bygroups(using(this, state='bqstring'), Text, Punctuation)),
+            include('for2')
+        ],
+        'for/l': [
+            (r'-?\d+', Number.Integer),
+            include('for2')
+        ],
+        'if': [
+            (rf'((?:cmdextversion|errorlevel){_token_terminator})({_space})(\d+)',
+             bygroups(Keyword, using(this, state='text'),
+                      Number.Integer), '#pop'),
+            (rf'(defined{_token_terminator})({_space})({_stoken})',
+             bygroups(Keyword, using(this, state='text'),
+                      using(this, state='variable')), '#pop'),
+            (rf'(exist{_token_terminator})({_space}{_stoken})',
+             bygroups(Keyword, using(this, state='text')), '#pop'),
+            (rf'({_number}{_space})({_opword})({_space}{_number})',
+             bygroups(using(this, state='arithmetic'), Operator.Word,
+                      using(this, state='arithmetic')), '#pop'),
+            (_stoken, using(this, state='text'), ('#pop', 'if2')),
+        ],
+        'if2': [
+            (rf'({_space}?)(==)({_space}?{_stoken})',
+             bygroups(using(this, state='text'), Operator,
+                      using(this, state='text')), '#pop'),
+            (rf'({_space})({_opword})({_space}{_stoken})',
+             bygroups(using(this, state='text'), Operator.Word,
+                      using(this, state='text')), '#pop')
+        ],
+        '(?': [
+            (_space, using(this, state='text')),
+            (r'\(', Punctuation, ('#pop', 'else?', 'root/compound')),
+            default('#pop')
+        ],
+        'else?': [
+            (_space, using(this, state='text')),
+            (rf'else{_token_terminator}', Keyword, '#pop'),
+            default('#pop')
+        ]
+    }
+
+
+class MSDOSSessionLexer(ShellSessionBaseLexer):
+    """
+    Lexer for MS DOS shell sessions, i.e. command lines, including a
+    prompt, interspersed with output.
+    """
+
+    name = 'MSDOS Session'
+    aliases = ['doscon']
+    filenames = []
+    mimetypes = []
+    url = 'https://en.wikipedia.org/wiki/MS-DOS'
+    version_added = '2.1'
+    _example = "doscon/session"
+
+    _innerLexerCls = BatchLexer
+    _ps1rgx = re.compile(r'^([^>]*>)(.*\n?)')
+    _ps2 = 'More? '
+
+
+class TcshLexer(RegexLexer):
+    """
+    Lexer for tcsh scripts.
+    """
+
+    name = 'Tcsh'
+    aliases = ['tcsh', 'csh']
+    filenames = ['*.tcsh', '*.csh']
+    mimetypes = ['application/x-csh']
+    url = 'https://www.tcsh.org'
+    version_added = '0.10'
+
+    tokens = {
+        'root': [
+            include('basic'),
+            (r'\$\(', Keyword, 'paren'),
+            (r'\$\{#?', Keyword, 'curly'),
+            (r'`', String.Backtick, 'backticks'),
+            include('data'),
+        ],
+        'basic': [
+            (r'\b(if|endif|else|while|then|foreach|case|default|'
+             r'break|continue|goto|breaksw|end|switch|endsw)\s*\b',
+             Keyword),
+            (r'\b(alias|alloc|bg|bindkey|builtins|bye|caller|cd|chdir|'
+             r'complete|dirs|echo|echotc|eval|exec|exit|fg|filetest|getxvers|'
+             r'glob|getspath|hashstat|history|hup|inlib|jobs|kill|'
+             r'limit|log|login|logout|ls-F|migrate|newgrp|nice|nohup|notify|'
+             r'onintr|popd|printenv|pushd|rehash|repeat|rootnode|popd|pushd|'
+             r'set|shift|sched|setenv|setpath|settc|setty|setxvers|shift|'
+             r'source|stop|suspend|source|suspend|telltc|time|'
+             r'umask|unalias|uncomplete|unhash|universe|unlimit|unset|unsetenv|'
+             r'ver|wait|warp|watchlog|where|which)\s*\b',
+             Name.Builtin),
+            (r'#.*', Comment),
+            (r'\\[\w\W]', String.Escape),
+            (r'(\b\w+)(\s*)(=)', bygroups(Name.Variable, Text, Operator)),
+            (r'[\[\]{}()=]+', Operator),
+            (r'<<\s*(\'?)\\?(\w+)[\w\W]+?\2', String),
+            (r';', Punctuation),
+        ],
+        'data': [
+            (r'(?s)"(\\\\|\\[0-7]+|\\.|[^"\\])*"', String.Double),
+            (r"(?s)'(\\\\|\\[0-7]+|\\.|[^'\\])*'", String.Single),
+            (r'\s+', Text),
+            (r'[^=\s\[\]{}()$"\'`\\;#]+', Text),
+            (r'\d+(?= |\Z)', Number),
+            (r'\$#?(\w+|.)', Name.Variable),
+        ],
+        'curly': [
+            (r'\}', Keyword, '#pop'),
+            (r':-', Keyword),
+            (r'\w+', Name.Variable),
+            (r'[^}:"\'`$]+', Punctuation),
+            (r':', Punctuation),
+            include('root'),
+        ],
+        'paren': [
+            (r'\)', Keyword, '#pop'),
+            include('root'),
+        ],
+        'backticks': [
+            (r'`', String.Backtick, '#pop'),
+            include('root'),
+        ],
+    }
+
+
+class TcshSessionLexer(ShellSessionBaseLexer):
+    """
+    Lexer for Tcsh sessions, i.e. command lines, including a
+    prompt, interspersed with output.
+    """
+
+    name = 'Tcsh Session'
+    aliases = ['tcshcon']
+    filenames = []
+    mimetypes = []
+    url = 'https://www.tcsh.org'
+    version_added = '2.1'
+    _example = "tcshcon/session"
+
+    _innerLexerCls = TcshLexer
+    _ps1rgx = re.compile(r'^([^>]+>)(.*\n?)')
+    _ps2 = '? '
+
+
+class PowerShellLexer(RegexLexer):
+    """
+    For Windows PowerShell code.
+    """
+    name = 'PowerShell'
+    aliases = ['powershell', 'pwsh', 'posh', 'ps1', 'psm1']
+    filenames = ['*.ps1', '*.psm1']
+    mimetypes = ['text/x-powershell']
+    url = 'https://learn.microsoft.com/en-us/powershell'
+    version_added = '1.5'
+
+    flags = re.DOTALL | re.IGNORECASE | re.MULTILINE
+
+    keywords = (
+        'while validateset validaterange validatepattern validatelength '
+        'validatecount until trap switch return ref process param parameter in '
+        'if global: local: function foreach for finally filter end elseif else '
+        'dynamicparam do default continue cmdletbinding break begin alias \\? '
+        '% #script #private #local #global mandatory parametersetname position '
+        'valuefrompipeline valuefrompipelinebypropertyname '
+        'valuefromremainingarguments helpmessage try catch throw').split()
+
+    operators = (
+        'and as band bnot bor bxor casesensitive ccontains ceq cge cgt cle '
+        'clike clt cmatch cne cnotcontains cnotlike cnotmatch contains '
+        'creplace eq exact f file ge gt icontains ieq ige igt ile ilike ilt '
+        'imatch ine inotcontains inotlike inotmatch ireplace is isnot le like '
+        'lt match ne not notcontains notlike notmatch or regex replace '
+        'wildcard').split()
+
+    verbs = (
+        'write where watch wait use update unregister unpublish unprotect '
+        'unlock uninstall undo unblock trace test tee take sync switch '
+        'suspend submit stop step start split sort skip show set send select '
+        'search scroll save revoke resume restore restart resolve resize '
+        'reset request repair rename remove register redo receive read push '
+        'publish protect pop ping out optimize open new move mount merge '
+        'measure lock limit join invoke install initialize import hide group '
+        'grant get format foreach find export expand exit enter enable edit '
+        'dismount disconnect disable deny debug cxnew copy convertto '
+        'convertfrom convert connect confirm compress complete compare close '
+        'clear checkpoint block backup assert approve aggregate add').split()
+
+    aliases_ = (
+        'ac asnp cat cd cfs chdir clc clear clhy cli clp cls clv cnsn '
+        'compare copy cp cpi cpp curl cvpa dbp del diff dir dnsn ebp echo epal '
+        'epcsv epsn erase etsn exsn fc fhx fl foreach ft fw gal gbp gc gci gcm '
+        'gcs gdr ghy gi gjb gl gm gmo gp gps gpv group gsn gsnp gsv gu gv gwmi '
+        'h history icm iex ihy ii ipal ipcsv ipmo ipsn irm ise iwmi iwr kill lp '
+        'ls man md measure mi mount move mp mv nal ndr ni nmo npssc nsn nv ogv '
+        'oh popd ps pushd pwd r rbp rcjb rcsn rd rdr ren ri rjb rm rmdir rmo '
+        'rni rnp rp rsn rsnp rujb rv rvpa rwmi sajb sal saps sasv sbp sc select '
+        'set shcm si sl sleep sls sort sp spjb spps spsv start sujb sv swmi tee '
+        'trcm type wget where wjb write').split()
+
+    commenthelp = (
+        'component description example externalhelp forwardhelpcategory '
+        'forwardhelptargetname functionality inputs link '
+        'notes outputs parameter remotehelprunspace role synopsis').split()
+
+    tokens = {
+        'root': [
+            # we need to count pairs of parentheses for correct highlight
+            # of '$(...)' blocks in strings
+            (r'\(', Punctuation, 'child'),
+            (r'\s+', Text),
+            (r'^(\s*#[#\s]*)(\.(?:{}))([^\n]*$)'.format('|'.join(commenthelp)),
+             bygroups(Comment, String.Doc, Comment)),
+            (r'#[^\n]*?$', Comment),
+            (r'(<|<)#', Comment.Multiline, 'multline'),
+            (r'@"\n', String.Heredoc, 'heredoc-double'),
+            (r"@'\n.*?\n'@", String.Heredoc),
+            # escaped syntax
+            (r'`[\'"$@-]', Punctuation),
+            (r'"', String.Double, 'string'),
+            (r"'([^']|'')*'", String.Single),
+            (r'(\$|@@|@)((global|script|private|env):)?\w+',
+             Name.Variable),
+            (r'({})\b'.format('|'.join(keywords)), Keyword),
+            (r'-({})\b'.format('|'.join(operators)), Operator),
+            (r'({})-[a-z_]\w*\b'.format('|'.join(verbs)), Name.Builtin),
+            (r'({})\s'.format('|'.join(aliases_)), Name.Builtin),
+            (r'\[[a-z_\[][\w. `,\[\]]*\]', Name.Constant),  # .net [type]s
+            (r'-[a-z_]\w*', Name),
+            (r'\w+', Name),
+            (r'[.,;:@{}\[\]$()=+*/\\&%!~?^`|<>-]', Punctuation),
+        ],
+        'child': [
+            (r'\)', Punctuation, '#pop'),
+            include('root'),
+        ],
+        'multline': [
+            (r'[^#&.]+', Comment.Multiline),
+            (r'#(>|>)', Comment.Multiline, '#pop'),
+            (r'\.({})'.format('|'.join(commenthelp)), String.Doc),
+            (r'[#&.]', Comment.Multiline),
+        ],
+        'string': [
+            (r"`[0abfnrtv'\"$`]", String.Escape),
+            (r'[^$`"]+', String.Double),
+            (r'\$\(', Punctuation, 'child'),
+            (r'""', String.Double),
+            (r'[`$]', String.Double),
+            (r'"', String.Double, '#pop'),
+        ],
+        'heredoc-double': [
+            (r'\n"@', String.Heredoc, '#pop'),
+            (r'\$\(', Punctuation, 'child'),
+            (r'[^@\n]+"]', String.Heredoc),
+            (r".", String.Heredoc),
+        ]
+    }
+
+
+class PowerShellSessionLexer(ShellSessionBaseLexer):
+    """
+    Lexer for PowerShell sessions, i.e. command lines, including a
+    prompt, interspersed with output.
+    """
+
+    name = 'PowerShell Session'
+    aliases = ['pwsh-session', 'ps1con']
+    filenames = []
+    mimetypes = []
+    url = 'https://learn.microsoft.com/en-us/powershell'
+    version_added = '2.1'
+    _example = "pwsh-session/session"
+
+    _innerLexerCls = PowerShellLexer
+    _bare_continuation = True
+    _ps1rgx = re.compile(r'^((?:\[[^]]+\]: )?PS[^>]*> ?)(.*\n?)')
+    _ps2 = '> '
+
+
+class FishShellLexer(RegexLexer):
+    """
+    Lexer for Fish shell scripts.
+    """
+
+    name = 'Fish'
+    aliases = ['fish', 'fishshell']
+    filenames = ['*.fish', '*.load']
+    mimetypes = ['application/x-fish']
+    url = 'https://fishshell.com'
+    version_added = '2.1'
+
+    tokens = {
+        'root': [
+            include('basic'),
+            include('data'),
+            include('interp'),
+        ],
+        'interp': [
+            (r'\$\(\(', Keyword, 'math'),
+            (r'\(', Keyword, 'paren'),
+            (r'\$#?(\w+|.)', Name.Variable),
+        ],
+        'basic': [
+            (r'\b(begin|end|if|else|while|break|for|in|return|function|block|'
+             r'case|continue|switch|not|and|or|set|echo|exit|pwd|true|false|'
+             r'cd|count|test)(\s*)\b',
+             bygroups(Keyword, Text)),
+            (r'\b(alias|bg|bind|breakpoint|builtin|command|commandline|'
+             r'complete|contains|dirh|dirs|emit|eval|exec|fg|fish|fish_config|'
+             r'fish_indent|fish_pager|fish_prompt|fish_right_prompt|'
+             r'fish_update_completions|fishd|funced|funcsave|functions|help|'
+             r'history|isatty|jobs|math|mimedb|nextd|open|popd|prevd|psub|'
+             r'pushd|random|read|set_color|source|status|trap|type|ulimit|'
+             r'umask|vared|fc|getopts|hash|kill|printf|time|wait)\s*\b(?!\.)',
+             Name.Builtin),
+            (r'#.*\n', Comment),
+            (r'\\[\w\W]', String.Escape),
+            (r'(\b\w+)(\s*)(=)', bygroups(Name.Variable, Whitespace, Operator)),
+            (r'[\[\]()=]', Operator),
+            (r'<<-?\s*(\'?)\\?(\w+)[\w\W]+?\2', String),
+        ],
+        'data': [
+            (r'(?s)\$?"(\\\\|\\[0-7]+|\\.|[^"\\$])*"', String.Double),
+            (r'"', String.Double, 'string'),
+            (r"(?s)\$'(\\\\|\\[0-7]+|\\.|[^'\\])*'", String.Single),
+            (r"(?s)'.*?'", String.Single),
+            (r';', Punctuation),
+            (r'&|\||\^|<|>', Operator),
+            (r'\s+', Text),
+            (r'\d+(?= |\Z)', Number),
+            (r'[^=\s\[\]{}()$"\'`\\<&|;]+', Text),
+        ],
+        'string': [
+            (r'"', String.Double, '#pop'),
+            (r'(?s)(\\\\|\\[0-7]+|\\.|[^"\\$])+', String.Double),
+            include('interp'),
+        ],
+        'paren': [
+            (r'\)', Keyword, '#pop'),
+            include('root'),
+        ],
+        'math': [
+            (r'\)\)', Keyword, '#pop'),
+            (r'[-+*/%^|&]|\*\*|\|\|', Operator),
+            (r'\d+#\d+', Number),
+            (r'\d+#(?! )', Number),
+            (r'\d+', Number),
+            include('root'),
+        ],
+    }
+
+class ExeclineLexer(RegexLexer):
+    """
+    Lexer for Laurent Bercot's execline language.
+    """
+
+    name = 'execline'
+    aliases = ['execline']
+    filenames = ['*.exec']
+    url = 'https://skarnet.org/software/execline'
+    version_added = '2.7'
+
+    tokens = {
+        'root': [
+            include('basic'),
+            include('data'),
+            include('interp')
+        ],
+        'interp': [
+            (r'\$\{', String.Interpol, 'curly'),
+            (r'\$[\w@#]+', Name.Variable),  # user variable
+            (r'\$', Text),
+        ],
+        'basic': [
+            (r'\b(background|backtick|cd|define|dollarat|elgetopt|'
+             r'elgetpositionals|elglob|emptyenv|envfile|exec|execlineb|'
+             r'exit|export|fdblock|fdclose|fdmove|fdreserve|fdswap|'
+             r'forbacktickx|foreground|forstdin|forx|getcwd|getpid|heredoc|'
+             r'homeof|if|ifelse|ifte|ifthenelse|importas|loopwhilex|'
+             r'multidefine|multisubstitute|pipeline|piperw|posix-cd|'
+             r'redirfd|runblock|shift|trap|tryexec|umask|unexport|wait|'
+             r'withstdinas)\b', Name.Builtin),
+            (r'\A#!.+\n', Comment.Hashbang),
+            (r'#.*\n', Comment.Single),
+            (r'[{}]', Operator)
+        ],
+        'data': [
+            (r'(?s)"(\\.|[^"\\$])*"', String.Double),
+            (r'"', String.Double, 'string'),
+            (r'\s+', Text),
+            (r'[^\s{}$"\\]+', Text)
+        ],
+        'string': [
+            (r'"', String.Double, '#pop'),
+            (r'(?s)(\\\\|\\.|[^"\\$])+', String.Double),
+            include('interp'),
+        ],
+        'curly': [
+            (r'\}', String.Interpol, '#pop'),
+            (r'[\w#@]+', Name.Variable),
+            include('root')
+        ]
+
+    }
+
+    def analyse_text(text):
+        if shebang_matches(text, r'execlineb'):
+            return 1
diff --git a/pygments/lexers/sieve.py b/pygments/lexers/sieve.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc48980c49921d811b5bed122c531cc8213be463
--- /dev/null
+++ b/pygments/lexers/sieve.py
@@ -0,0 +1,78 @@
+"""
+    pygments.lexers.sieve
+    ~~~~~~~~~~~~~~~~~~~~~
+
+    Lexer for Sieve file format.
+
+    https://tools.ietf.org/html/rfc5228
+    https://tools.ietf.org/html/rfc5173
+    https://tools.ietf.org/html/rfc5229
+    https://tools.ietf.org/html/rfc5230
+    https://tools.ietf.org/html/rfc5232
+    https://tools.ietf.org/html/rfc5235
+    https://tools.ietf.org/html/rfc5429
+    https://tools.ietf.org/html/rfc8580
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, bygroups
+from pygments.token import Comment, Name, Literal, String, Text, Punctuation, \
+    Keyword
+
+__all__ = ["SieveLexer"]
+
+
+class SieveLexer(RegexLexer):
+    """
+    Lexer for sieve format.
+    """
+    name = 'Sieve'
+    filenames = ['*.siv', '*.sieve']
+    aliases = ['sieve']
+    url = 'https://en.wikipedia.org/wiki/Sieve_(mail_filtering_language)'
+    version_added = '2.6'
+
+    tokens = {
+        'root': [
+            (r'\s+', Text),
+            (r'[();,{}\[\]]', Punctuation),
+            # import:
+            (r'(?i)require',
+             Keyword.Namespace),
+            # tags:
+            (r'(?i)(:)(addresses|all|contains|content|create|copy|comparator|'
+             r'count|days|detail|domain|fcc|flags|from|handle|importance|is|'
+             r'localpart|length|lowerfirst|lower|matches|message|mime|options|'
+             r'over|percent|quotewildcard|raw|regex|specialuse|subject|text|'
+             r'under|upperfirst|upper|value)',
+             bygroups(Name.Tag, Name.Tag)),
+            # tokens:
+            (r'(?i)(address|addflag|allof|anyof|body|discard|elsif|else|envelope|'
+             r'ereject|exists|false|fileinto|if|hasflag|header|keep|'
+             r'notify_method_capability|notify|not|redirect|reject|removeflag|'
+             r'setflag|size|spamtest|stop|string|true|vacation|virustest)',
+             Name.Builtin),
+            (r'(?i)set',
+             Keyword.Declaration),
+            # number:
+            (r'([0-9.]+)([kmgKMG])?',
+             bygroups(Literal.Number, Literal.Number)),
+            # comment:
+            (r'#.*$',
+             Comment.Single),
+            (r'/\*.*\*/',
+             Comment.Multiline),
+            # string:
+            (r'"[^"]*?"',
+             String),
+            # text block:
+            (r'text:',
+             Name.Tag, 'text'),
+        ],
+        'text': [
+            (r'[^.].*?\n', String),
+            (r'^\.', Punctuation, "#pop"),
+        ]
+    }
diff --git a/pygments/lexers/slash.py b/pygments/lexers/slash.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c439d0db0fd51d6870f1294b2eb276aa64e8a6c
--- /dev/null
+++ b/pygments/lexers/slash.py
@@ -0,0 +1,183 @@
+"""
+    pygments.lexers.slash
+    ~~~~~~~~~~~~~~~~~~~~~
+
+    Lexer for the Slash programming language.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import ExtendedRegexLexer, bygroups, DelegatingLexer
+from pygments.token import Name, Number, String, Comment, Punctuation, \
+    Other, Keyword, Operator, Whitespace
+
+__all__ = ['SlashLexer']
+
+
+class SlashLanguageLexer(ExtendedRegexLexer):
+    _nkw = r'(?=[^a-zA-Z_0-9])'
+
+    def move_state(new_state):
+        return ("#pop", new_state)
+
+    def right_angle_bracket(lexer, match, ctx):
+        if len(ctx.stack) > 1 and ctx.stack[-2] == "string":
+            ctx.stack.pop()
+        yield match.start(), String.Interpol, '}'
+        ctx.pos = match.end()
+        pass
+
+    tokens = {
+        "root": [
+            (r"<%=",        Comment.Preproc,    move_state("slash")),
+            (r"<%!!",       Comment.Preproc,    move_state("slash")),
+            (r"<%#.*?%>",   Comment.Multiline),
+            (r"<%",         Comment.Preproc,    move_state("slash")),
+            (r".|\n",       Other),
+        ],
+        "string": [
+            (r"\\",         String.Escape,      move_state("string_e")),
+            (r"\"",         String,             move_state("slash")),
+            (r"#\{",        String.Interpol,    "slash"),
+            (r'.|\n',       String),
+        ],
+        "string_e": [
+            (r'n',                  String.Escape,      move_state("string")),
+            (r't',                  String.Escape,      move_state("string")),
+            (r'r',                  String.Escape,      move_state("string")),
+            (r'e',                  String.Escape,      move_state("string")),
+            (r'x[a-fA-F0-9]{2}',    String.Escape,      move_state("string")),
+            (r'.',                  String.Escape,      move_state("string")),
+        ],
+        "regexp": [
+            (r'}[a-z]*',            String.Regex,       move_state("slash")),
+            (r'\\(.|\n)',           String.Regex),
+            (r'{',                  String.Regex,       "regexp_r"),
+            (r'.|\n',               String.Regex),
+        ],
+        "regexp_r": [
+            (r'}[a-z]*',            String.Regex,       "#pop"),
+            (r'\\(.|\n)',           String.Regex),
+            (r'{',                  String.Regex,       "regexp_r"),
+        ],
+        "slash": [
+            (r"%>",                     Comment.Preproc,    move_state("root")),
+            (r"\"",                     String,             move_state("string")),
+            (r"'[a-zA-Z0-9_]+",         String),
+            (r'%r{',                    String.Regex,       move_state("regexp")),
+            (r'/\*.*?\*/',              Comment.Multiline),
+            (r"(#|//).*?\n",            Comment.Single),
+            (r'-?[0-9]+e[+-]?[0-9]+',   Number.Float),
+            (r'-?[0-9]+\.[0-9]+(e[+-]?[0-9]+)?', Number.Float),
+            (r'-?[0-9]+',               Number.Integer),
+            (r'nil'+_nkw,               Name.Builtin),
+            (r'true'+_nkw,              Name.Builtin),
+            (r'false'+_nkw,             Name.Builtin),
+            (r'self'+_nkw,              Name.Builtin),
+            (r'(class)(\s+)([A-Z][a-zA-Z0-9_\']*)',
+                bygroups(Keyword, Whitespace, Name.Class)),
+            (r'class'+_nkw,             Keyword),
+            (r'extends'+_nkw,           Keyword),
+            (r'(def)(\s+)(self)(\s*)(\.)(\s*)([a-z_][a-zA-Z0-9_\']*=?|<<|>>|==|<=>|<=|<|>=|>|\+|-(self)?|~(self)?|\*|/|%|^|&&|&|\||\[\]=?)',
+                bygroups(Keyword, Whitespace, Name.Builtin, Whitespace, Punctuation, Whitespace, Name.Function)),
+            (r'(def)(\s+)([a-z_][a-zA-Z0-9_\']*=?|<<|>>|==|<=>|<=|<|>=|>|\+|-(self)?|~(self)?|\*|/|%|^|&&|&|\||\[\]=?)',
+                bygroups(Keyword, Whitespace, Name.Function)),
+            (r'def'+_nkw,               Keyword),
+            (r'if'+_nkw,                Keyword),
+            (r'elsif'+_nkw,             Keyword),
+            (r'else'+_nkw,              Keyword),
+            (r'unless'+_nkw,            Keyword),
+            (r'for'+_nkw,               Keyword),
+            (r'in'+_nkw,                Keyword),
+            (r'while'+_nkw,             Keyword),
+            (r'until'+_nkw,             Keyword),
+            (r'and'+_nkw,               Keyword),
+            (r'or'+_nkw,                Keyword),
+            (r'not'+_nkw,               Keyword),
+            (r'lambda'+_nkw,            Keyword),
+            (r'try'+_nkw,               Keyword),
+            (r'catch'+_nkw,             Keyword),
+            (r'return'+_nkw,            Keyword),
+            (r'next'+_nkw,              Keyword),
+            (r'last'+_nkw,              Keyword),
+            (r'throw'+_nkw,             Keyword),
+            (r'use'+_nkw,               Keyword),
+            (r'switch'+_nkw,            Keyword),
+            (r'\\',                     Keyword),
+            (r'λ',                      Keyword),
+            (r'__FILE__'+_nkw,          Name.Builtin.Pseudo),
+            (r'__LINE__'+_nkw,          Name.Builtin.Pseudo),
+            (r'[A-Z][a-zA-Z0-9_\']*'+_nkw, Name.Constant),
+            (r'[a-z_][a-zA-Z0-9_\']*'+_nkw, Name),
+            (r'@[a-z_][a-zA-Z0-9_\']*'+_nkw, Name.Variable.Instance),
+            (r'@@[a-z_][a-zA-Z0-9_\']*'+_nkw, Name.Variable.Class),
+            (r'\(',                     Punctuation),
+            (r'\)',                     Punctuation),
+            (r'\[',                     Punctuation),
+            (r'\]',                     Punctuation),
+            (r'\{',                     Punctuation),
+            (r'\}',                     right_angle_bracket),
+            (r';',                      Punctuation),
+            (r',',                      Punctuation),
+            (r'<<=',                    Operator),
+            (r'>>=',                    Operator),
+            (r'<<',                     Operator),
+            (r'>>',                     Operator),
+            (r'==',                     Operator),
+            (r'!=',                     Operator),
+            (r'=>',                     Operator),
+            (r'=',                      Operator),
+            (r'<=>',                    Operator),
+            (r'<=',                     Operator),
+            (r'>=',                     Operator),
+            (r'<',                      Operator),
+            (r'>',                      Operator),
+            (r'\+\+',                   Operator),
+            (r'\+=',                    Operator),
+            (r'-=',                     Operator),
+            (r'\*\*=',                  Operator),
+            (r'\*=',                    Operator),
+            (r'\*\*',                   Operator),
+            (r'\*',                     Operator),
+            (r'/=',                     Operator),
+            (r'\+',                     Operator),
+            (r'-',                      Operator),
+            (r'/',                      Operator),
+            (r'%=',                     Operator),
+            (r'%',                      Operator),
+            (r'^=',                     Operator),
+            (r'&&=',                    Operator),
+            (r'&=',                     Operator),
+            (r'&&',                     Operator),
+            (r'&',                      Operator),
+            (r'\|\|=',                  Operator),
+            (r'\|=',                    Operator),
+            (r'\|\|',                   Operator),
+            (r'\|',                     Operator),
+            (r'!',                      Operator),
+            (r'\.\.\.',                 Operator),
+            (r'\.\.',                   Operator),
+            (r'\.',                     Operator),
+            (r'::',                     Operator),
+            (r':',                      Operator),
+            (r'(\s|\n)+',               Whitespace),
+            (r'[a-z_][a-zA-Z0-9_\']*',  Name.Variable),
+        ],
+    }
+
+
+class SlashLexer(DelegatingLexer):
+    """
+    Lexer for the Slash programming language.
+    """
+
+    name = 'Slash'
+    aliases = ['slash']
+    filenames = ['*.sla']
+    url = 'https://github.com/arturadib/Slash-A'
+    version_added = '2.4'
+
+    def __init__(self, **options):
+        from pygments.lexers.web import HtmlLexer
+        super().__init__(HtmlLexer, SlashLanguageLexer, **options)
diff --git a/pygments/lexers/smalltalk.py b/pygments/lexers/smalltalk.py
new file mode 100644
index 0000000000000000000000000000000000000000..674b7b4b345a4b968f080bcd1552979201da6164
--- /dev/null
+++ b/pygments/lexers/smalltalk.py
@@ -0,0 +1,194 @@
+"""
+    pygments.lexers.smalltalk
+    ~~~~~~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for Smalltalk and related languages.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, include, bygroups, default
+from pygments.token import Text, Comment, Operator, Keyword, Name, String, \
+    Number, Punctuation
+
+__all__ = ['SmalltalkLexer', 'NewspeakLexer']
+
+
+class SmalltalkLexer(RegexLexer):
+    """
+    For Smalltalk syntax.
+    Contributed by Stefan Matthias Aust.
+    Rewritten by Nils Winter.
+    """
+    name = 'Smalltalk'
+    url = 'http://www.smalltalk.org/'
+    filenames = ['*.st']
+    aliases = ['smalltalk', 'squeak', 'st']
+    mimetypes = ['text/x-smalltalk']
+    version_added = '0.10'
+
+    tokens = {
+        'root': [
+            (r'(<)(\w+:)(.*?)(>)', bygroups(Text, Keyword, Text, Text)),
+            include('squeak fileout'),
+            include('whitespaces'),
+            include('method definition'),
+            (r'(\|)([\w\s]*)(\|)', bygroups(Operator, Name.Variable, Operator)),
+            include('objects'),
+            (r'\^|\:=|\_', Operator),
+            # temporaries
+            (r'[\]({}.;!]', Text),
+        ],
+        'method definition': [
+            # Not perfect can't allow whitespaces at the beginning and the
+            # without breaking everything
+            (r'([a-zA-Z]+\w*:)(\s*)(\w+)',
+             bygroups(Name.Function, Text, Name.Variable)),
+            (r'^(\b[a-zA-Z]+\w*\b)(\s*)$', bygroups(Name.Function, Text)),
+            (r'^([-+*/\\~<>=|&!?,@%]+)(\s*)(\w+)(\s*)$',
+             bygroups(Name.Function, Text, Name.Variable, Text)),
+        ],
+        'blockvariables': [
+            include('whitespaces'),
+            (r'(:)(\s*)(\w+)',
+             bygroups(Operator, Text, Name.Variable)),
+            (r'\|', Operator, '#pop'),
+            default('#pop'),  # else pop
+        ],
+        'literals': [
+            (r"'(''|[^'])*'", String, 'afterobject'),
+            (r'\$.', String.Char, 'afterobject'),
+            (r'#\(', String.Symbol, 'parenth'),
+            (r'\)', Text, 'afterobject'),
+            (r'(\d+r)?-?\d+(\.\d+)?(e-?\d+)?', Number, 'afterobject'),
+        ],
+        '_parenth_helper': [
+            include('whitespaces'),
+            (r'(\d+r)?-?\d+(\.\d+)?(e-?\d+)?', Number),
+            (r'[-+*/\\~<>=|&#!?,@%\w:]+', String.Symbol),
+            # literals
+            (r"'(''|[^'])*'", String),
+            (r'\$.', String.Char),
+            (r'#*\(', String.Symbol, 'inner_parenth'),
+        ],
+        'parenth': [
+            # This state is a bit tricky since
+            # we can't just pop this state
+            (r'\)', String.Symbol, ('root', 'afterobject')),
+            include('_parenth_helper'),
+        ],
+        'inner_parenth': [
+            (r'\)', String.Symbol, '#pop'),
+            include('_parenth_helper'),
+        ],
+        'whitespaces': [
+            # skip whitespace and comments
+            (r'\s+', Text),
+            (r'"(""|[^"])*"', Comment),
+        ],
+        'objects': [
+            (r'\[', Text, 'blockvariables'),
+            (r'\]', Text, 'afterobject'),
+            (r'\b(self|super|true|false|nil|thisContext)\b',
+             Name.Builtin.Pseudo, 'afterobject'),
+            (r'\b[A-Z]\w*(?!:)\b', Name.Class, 'afterobject'),
+            (r'\b[a-z]\w*(?!:)\b', Name.Variable, 'afterobject'),
+            (r'#("(""|[^"])*"|[-+*/\\~<>=|&!?,@%]+|[\w:]+)',
+             String.Symbol, 'afterobject'),
+            include('literals'),
+        ],
+        'afterobject': [
+            (r'! !$', Keyword, '#pop'),  # squeak chunk delimiter
+            include('whitespaces'),
+            (r'\b(ifTrue:|ifFalse:|whileTrue:|whileFalse:|timesRepeat:)',
+             Name.Builtin, '#pop'),
+            (r'\b(new\b(?!:))', Name.Builtin),
+            (r'\:=|\_', Operator, '#pop'),
+            (r'\b[a-zA-Z]+\w*:', Name.Function, '#pop'),
+            (r'\b[a-zA-Z]+\w*', Name.Function),
+            (r'\w+:?|[-+*/\\~<>=|&!?,@%]+', Name.Function, '#pop'),
+            (r'\.', Punctuation, '#pop'),
+            (r';', Punctuation),
+            (r'[\])}]', Text),
+            (r'[\[({]', Text, '#pop'),
+        ],
+        'squeak fileout': [
+            # Squeak fileout format (optional)
+            (r'^"(""|[^"])*"!', Keyword),
+            (r"^'(''|[^'])*'!", Keyword),
+            (r'^(!)(\w+)( commentStamp: )(.*?)( prior: .*?!\n)(.*?)(!)',
+                bygroups(Keyword, Name.Class, Keyword, String, Keyword, Text, Keyword)),
+            (r"^(!)(\w+(?: class)?)( methodsFor: )('(?:''|[^'])*')(.*?!)",
+                bygroups(Keyword, Name.Class, Keyword, String, Keyword)),
+            (r'^(\w+)( subclass: )(#\w+)'
+             r'(\s+instanceVariableNames: )(.*?)'
+             r'(\s+classVariableNames: )(.*?)'
+             r'(\s+poolDictionaries: )(.*?)'
+             r'(\s+category: )(.*?)(!)',
+                bygroups(Name.Class, Keyword, String.Symbol, Keyword, String, Keyword,
+                         String, Keyword, String, Keyword, String, Keyword)),
+            (r'^(\w+(?: class)?)(\s+instanceVariableNames: )(.*?)(!)',
+                bygroups(Name.Class, Keyword, String, Keyword)),
+            (r'(!\n)(\].*)(! !)$', bygroups(Keyword, Text, Keyword)),
+            (r'! !$', Keyword),
+        ],
+    }
+
+
+class NewspeakLexer(RegexLexer):
+    """
+    For Newspeak syntax.
+    """
+    name = 'Newspeak'
+    url = 'http://newspeaklanguage.org/'
+    filenames = ['*.ns2']
+    aliases = ['newspeak', ]
+    mimetypes = ['text/x-newspeak']
+    version_added = '1.1'
+
+    tokens = {
+        'root': [
+            (r'\b(Newsqueak2)\b', Keyword.Declaration),
+            (r"'[^']*'", String),
+            (r'\b(class)(\s+)(\w+)(\s*)',
+             bygroups(Keyword.Declaration, Text, Name.Class, Text)),
+            (r'\b(mixin|self|super|private|public|protected|nil|true|false)\b',
+             Keyword),
+            (r'(\w+\:)(\s*)([a-zA-Z_]\w+)',
+             bygroups(Name.Function, Text, Name.Variable)),
+            (r'(\w+)(\s*)(=)',
+             bygroups(Name.Attribute, Text, Operator)),
+            (r'<\w+>', Comment.Special),
+            include('expressionstat'),
+            include('whitespace')
+        ],
+
+        'expressionstat': [
+            (r'(\d+\.\d*|\.\d+|\d+[fF])[fF]?', Number.Float),
+            (r'\d+', Number.Integer),
+            (r':\w+', Name.Variable),
+            (r'(\w+)(::)', bygroups(Name.Variable, Operator)),
+            (r'\w+:', Name.Function),
+            (r'\w+', Name.Variable),
+            (r'\(|\)', Punctuation),
+            (r'\[|\]', Punctuation),
+            (r'\{|\}', Punctuation),
+
+            (r'(\^|\+|\/|~|\*|<|>|=|@|%|\||&|\?|!|,|-|:)', Operator),
+            (r'\.|;', Punctuation),
+            include('whitespace'),
+            include('literals'),
+        ],
+        'literals': [
+            (r'\$.', String),
+            (r"'[^']*'", String),
+            (r"#'[^']*'", String.Symbol),
+            (r"#\w+:?", String.Symbol),
+            (r"#(\+|\/|~|\*|<|>|=|@|%|\||&|\?|!|,|-)+", String.Symbol)
+        ],
+        'whitespace': [
+            (r'\s+', Text),
+            (r'"[^"]*"', Comment)
+        ],
+    }
diff --git a/pygments/lexers/smithy.py b/pygments/lexers/smithy.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd479aec40c596032921b5cd24f249052058e04e
--- /dev/null
+++ b/pygments/lexers/smithy.py
@@ -0,0 +1,77 @@
+"""
+    pygments.lexers.smithy
+    ~~~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for the Smithy IDL.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, bygroups, words
+from pygments.token import Text, Comment, Keyword, Name, String, \
+    Number, Whitespace, Punctuation
+
+__all__ = ['SmithyLexer']
+
+
+class SmithyLexer(RegexLexer):
+    """
+    For Smithy IDL
+    """
+    name = 'Smithy'
+    url = 'https://awslabs.github.io/smithy/'
+    filenames = ['*.smithy']
+    aliases = ['smithy']
+    version_added = '2.10'
+
+    unquoted = r'[A-Za-z0-9_\.#$-]+'
+    identifier = r"[A-Za-z0-9_\.#$-]+"
+
+    simple_shapes = (
+        'use', 'byte', 'short', 'integer', 'long', 'float', 'document',
+        'double', 'bigInteger', 'bigDecimal', 'boolean', 'blob', 'string',
+        'timestamp',
+    )
+
+    aggregate_shapes = (
+       'apply', 'list', 'map', 'set', 'structure', 'union', 'resource',
+       'operation', 'service', 'trait'
+    )
+
+    tokens = {
+        'root': [
+            (r'///.*$', Comment.Multiline),
+            (r'//.*$', Comment),
+            (r'@[0-9a-zA-Z\.#-]*', Name.Decorator),
+            (r'(=)', Name.Decorator),
+            (r'^(\$version)(:)(.+)',
+                bygroups(Keyword.Declaration, Name.Decorator, Name.Class)),
+            (r'^(namespace)(\s+' + identifier + r')\b',
+                bygroups(Keyword.Declaration, Name.Class)),
+            (words(simple_shapes,
+                   prefix=r'^', suffix=r'(\s+' + identifier + r')\b'),
+                bygroups(Keyword.Declaration, Name.Class)),
+            (words(aggregate_shapes,
+                   prefix=r'^', suffix=r'(\s+' + identifier + r')'),
+                bygroups(Keyword.Declaration, Name.Class)),
+            (r'^(metadata)(\s+)((?:\S+)|(?:\"[^"]+\"))(\s*)(=)',
+                bygroups(Keyword.Declaration, Whitespace, Name.Class,
+                         Whitespace, Name.Decorator)),
+            (r"(true|false|null)", Keyword.Constant),
+            (r"(-?(?:0|[1-9]\d*)(?:\.\d+)?(?:[eE][+-]?\d+)?)", Number),
+            (identifier + ":", Name.Label),
+            (identifier, Name.Variable.Class),
+            (r'\[', Text, "#push"),
+            (r'\]', Text, "#pop"),
+            (r'\(', Text, "#push"),
+            (r'\)', Text, "#pop"),
+            (r'\{', Text, "#push"),
+            (r'\}', Text, "#pop"),
+            (r'"{3}(\\\\|\n|\\")*"{3}', String.Doc),
+            (r'"(\\\\|\n|\\"|[^"])*"', String.Double),
+            (r"'(\\\\|\n|\\'|[^'])*'", String.Single),
+            (r'[:,]+', Punctuation),
+            (r'\s+', Whitespace),
+        ]
+    }
diff --git a/pygments/lexers/smv.py b/pygments/lexers/smv.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf97b52a5ce1953909b43b66255f0324fd8f051e
--- /dev/null
+++ b/pygments/lexers/smv.py
@@ -0,0 +1,78 @@
+"""
+    pygments.lexers.smv
+    ~~~~~~~~~~~~~~~~~~~
+
+    Lexers for the SMV languages.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, words
+from pygments.token import Comment, Keyword, Name, Number, Operator, \
+    Punctuation, Text
+
+__all__ = ['NuSMVLexer']
+
+
+class NuSMVLexer(RegexLexer):
+    """
+    Lexer for the NuSMV language.
+    """
+
+    name = 'NuSMV'
+    aliases = ['nusmv']
+    filenames = ['*.smv']
+    mimetypes = []
+    url = 'https://nusmv.fbk.eu'
+    version_added = '2.2'
+
+    tokens = {
+        'root': [
+            # Comments
+            (r'(?s)\/\-\-.*?\-\-/', Comment),
+            (r'--.*\n', Comment),
+
+            # Reserved
+            (words(('MODULE', 'DEFINE', 'MDEFINE', 'CONSTANTS', 'VAR', 'IVAR',
+                    'FROZENVAR', 'INIT', 'TRANS', 'INVAR', 'SPEC', 'CTLSPEC',
+                    'LTLSPEC', 'PSLSPEC', 'COMPUTE', 'NAME', 'INVARSPEC',
+                    'FAIRNESS', 'JUSTICE', 'COMPASSION', 'ISA', 'ASSIGN',
+                    'CONSTRAINT', 'SIMPWFF', 'CTLWFF', 'LTLWFF', 'PSLWFF',
+                    'COMPWFF', 'IN', 'MIN', 'MAX', 'MIRROR', 'PRED',
+                    'PREDICATES'), suffix=r'(?![\w$#-])'),
+             Keyword.Declaration),
+            (r'process(?![\w$#-])', Keyword),
+            (words(('array', 'of', 'boolean', 'integer', 'real', 'word'),
+                   suffix=r'(?![\w$#-])'), Keyword.Type),
+            (words(('case', 'esac'), suffix=r'(?![\w$#-])'), Keyword),
+            (words(('word1', 'bool', 'signed', 'unsigned', 'extend', 'resize',
+                    'sizeof', 'uwconst', 'swconst', 'init', 'self', 'count',
+                    'abs', 'max', 'min'), suffix=r'(?![\w$#-])'),
+             Name.Builtin),
+            (words(('EX', 'AX', 'EF', 'AF', 'EG', 'AG', 'E', 'F', 'O', 'G',
+                    'H', 'X', 'Y', 'Z', 'A', 'U', 'S', 'V', 'T', 'BU', 'EBF',
+                    'ABF', 'EBG', 'ABG', 'next', 'mod', 'union', 'in', 'xor',
+                    'xnor'), suffix=r'(?![\w$#-])'),
+                Operator.Word),
+            (words(('TRUE', 'FALSE'), suffix=r'(?![\w$#-])'), Keyword.Constant),
+
+            # Names
+            (r'[a-zA-Z_][\w$#-]*', Name.Variable),
+
+            # Operators
+            (r':=', Operator),
+            (r'[-&|+*/<>!=]', Operator),
+
+            # Literals
+            (r'\-?\d+\b', Number.Integer),
+            (r'0[su][bB]\d*_[01_]+', Number.Bin),
+            (r'0[su][oO]\d*_[0-7_]+', Number.Oct),
+            (r'0[su][dD]\d*_[\d_]+', Number.Decimal),
+            (r'0[su][hH]\d*_[\da-fA-F_]+', Number.Hex),
+
+            # Whitespace, punctuation and the rest
+            (r'\s+', Text.Whitespace),
+            (r'[()\[\]{};?:.,]', Punctuation),
+        ],
+    }
diff --git a/pygments/lexers/snobol.py b/pygments/lexers/snobol.py
new file mode 100644
index 0000000000000000000000000000000000000000..bab51e9b1107c5e295febaa842c68ac93ee1b1fc
--- /dev/null
+++ b/pygments/lexers/snobol.py
@@ -0,0 +1,82 @@
+"""
+    pygments.lexers.snobol
+    ~~~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for the SNOBOL language.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, bygroups
+from pygments.token import Text, Comment, Operator, Keyword, Name, String, \
+    Number, Punctuation
+
+__all__ = ['SnobolLexer']
+
+
+class SnobolLexer(RegexLexer):
+    """
+    Lexer for the SNOBOL4 programming language.
+
+    Recognizes the common ASCII equivalents of the original SNOBOL4 operators.
+    Does not require spaces around binary operators.
+    """
+
+    name = "Snobol"
+    aliases = ["snobol"]
+    filenames = ['*.snobol']
+    mimetypes = ['text/x-snobol']
+    url = 'https://www.regressive.org/snobol4'
+    version_added = '1.5'
+
+    tokens = {
+        # root state, start of line
+        # comments, continuation lines, and directives start in column 1
+        # as do labels
+        'root': [
+            (r'\*.*\n', Comment),
+            (r'[+.] ', Punctuation, 'statement'),
+            (r'-.*\n', Comment),
+            (r'END\s*\n', Name.Label, 'heredoc'),
+            (r'[A-Za-z$][\w$]*', Name.Label, 'statement'),
+            (r'\s+', Text, 'statement'),
+        ],
+        # statement state, line after continuation or label
+        'statement': [
+            (r'\s*\n', Text, '#pop'),
+            (r'\s+', Text),
+            (r'(?<=[^\w.])(LT|LE|EQ|NE|GE|GT|INTEGER|IDENT|DIFFER|LGT|SIZE|'
+             r'REPLACE|TRIM|DUPL|REMDR|DATE|TIME|EVAL|APPLY|OPSYN|LOAD|UNLOAD|'
+             r'LEN|SPAN|BREAK|ANY|NOTANY|TAB|RTAB|REM|POS|RPOS|FAIL|FENCE|'
+             r'ABORT|ARB|ARBNO|BAL|SUCCEED|INPUT|OUTPUT|TERMINAL)(?=[^\w.])',
+             Name.Builtin),
+            (r'[A-Za-z][\w.]*', Name),
+            # ASCII equivalents of original operators
+            # | for the EBCDIC equivalent, ! likewise
+            # \ for EBCDIC negation
+            (r'\*\*|[?$.!%*/#+\-@|&\\=]', Operator),
+            (r'"[^"]*"', String),
+            (r"'[^']*'", String),
+            # Accept SPITBOL syntax for real numbers
+            # as well as Macro SNOBOL4
+            (r'[0-9]+(?=[^.EeDd])', Number.Integer),
+            (r'[0-9]+(\.[0-9]*)?([EDed][-+]?[0-9]+)?', Number.Float),
+            # Goto
+            (r':', Punctuation, 'goto'),
+            (r'[()<>,;]', Punctuation),
+        ],
+        # Goto block
+        'goto': [
+            (r'\s*\n', Text, "#pop:2"),
+            (r'\s+', Text),
+            (r'F|S', Keyword),
+            (r'(\()([A-Za-z][\w.]*)(\))',
+             bygroups(Punctuation, Name.Label, Punctuation))
+        ],
+        # everything after the END statement is basically one
+        # big heredoc.
+        'heredoc': [
+            (r'.*\n', String.Heredoc)
+        ]
+    }
diff --git a/pygments/lexers/solidity.py b/pygments/lexers/solidity.py
new file mode 100644
index 0000000000000000000000000000000000000000..3182a148a3830c6b55e6b4f194010121dfe0abc4
--- /dev/null
+++ b/pygments/lexers/solidity.py
@@ -0,0 +1,87 @@
+"""
+    pygments.lexers.solidity
+    ~~~~~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for Solidity.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, bygroups, include, words
+from pygments.token import Text, Comment, Operator, Keyword, Name, String, \
+    Number, Punctuation, Whitespace
+
+__all__ = ['SolidityLexer']
+
+
+class SolidityLexer(RegexLexer):
+    """
+    For Solidity source code.
+    """
+
+    name = 'Solidity'
+    aliases = ['solidity']
+    filenames = ['*.sol']
+    mimetypes = []
+    url = 'https://soliditylang.org'
+    version_added = '2.5'
+
+    datatype = (
+        r'\b(address|bool|(?:(?:bytes|hash|int|string|uint)(?:8|16|24|32|40|48|56|64'
+        r'|72|80|88|96|104|112|120|128|136|144|152|160|168|176|184|192|200|208'
+        r'|216|224|232|240|248|256)?))\b'
+    )
+
+    tokens = {
+        'root': [
+            include('whitespace'),
+            include('comments'),
+            (r'\bpragma\s+solidity\b', Keyword, 'pragma'),
+            (r'\b(contract)(\s+)([a-zA-Z_]\w*)',
+             bygroups(Keyword, Whitespace, Name.Entity)),
+            (datatype + r'(\s+)((?:external|public|internal|private)\s+)?' +
+             r'([a-zA-Z_]\w*)',
+             bygroups(Keyword.Type, Whitespace, Keyword, Name.Variable)),
+            (r'\b(enum|event|function|struct)(\s+)([a-zA-Z_]\w*)',
+             bygroups(Keyword.Type, Whitespace, Name.Variable)),
+            (r'\b(msg|block|tx)\.([A-Za-z_][a-zA-Z0-9_]*)\b', Keyword),
+            (words((
+                'block', 'break', 'constant', 'constructor', 'continue',
+                'contract', 'do', 'else', 'external', 'false', 'for',
+                'function', 'if', 'import', 'inherited', 'internal', 'is',
+                'library', 'mapping', 'memory', 'modifier', 'msg', 'new',
+                'payable', 'private', 'public', 'require', 'return',
+                'returns', 'struct', 'suicide', 'throw', 'this', 'true',
+                'tx', 'var', 'while'), prefix=r'\b', suffix=r'\b'),
+             Keyword.Type),
+            (words(('keccak256',), prefix=r'\b', suffix=r'\b'), Name.Builtin),
+            (datatype, Keyword.Type),
+            include('constants'),
+            (r'[a-zA-Z_]\w*', Text),
+            (r'[~!%^&*+=|?:<>/-]', Operator),
+            (r'[.;{}(),\[\]]', Punctuation)
+        ],
+        'comments': [
+            (r'//(\n|[\w\W]*?[^\\]\n)', Comment.Single),
+            (r'/(\\\n)?[*][\w\W]*?[*](\\\n)?/', Comment.Multiline),
+            (r'/(\\\n)?[*][\w\W]*', Comment.Multiline)
+        ],
+        'constants': [
+            (r'("(\\"|.)*?")', String.Double),
+            (r"('(\\'|.)*?')", String.Single),
+            (r'\b0[xX][0-9a-fA-F]+\b', Number.Hex),
+            (r'\b\d+\b', Number.Decimal),
+        ],
+        'pragma': [
+            include('whitespace'),
+            include('comments'),
+            (r'(\^|>=|<)(\s*)(\d+\.\d+\.\d+)',
+             bygroups(Operator, Whitespace, Keyword)),
+            (r';', Punctuation, '#pop')
+        ],
+        'whitespace': [
+            (r'\s+', Whitespace),
+            (r'\n', Whitespace)
+        ]
+    }
diff --git a/pygments/lexers/soong.py b/pygments/lexers/soong.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbf204dd22543f06c65849d54613c3a573dda899
--- /dev/null
+++ b/pygments/lexers/soong.py
@@ -0,0 +1,78 @@
+"""
+    pygments.lexers.soong
+    ~~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for Soong (Android.bp Blueprint) files.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, bygroups, include
+from pygments.token import Comment, Name, Number, Operator, Punctuation, \
+        String, Whitespace
+
+__all__ = ['SoongLexer']
+
+class SoongLexer(RegexLexer):
+    name = 'Soong'
+    version_added = '2.18'
+    url = 'https://source.android.com/docs/setup/reference/androidbp'
+    aliases = ['androidbp', 'bp', 'soong']
+    filenames = ['Android.bp']
+
+    tokens = {
+        'root': [
+            # A variable assignment
+            (r'(\w*)(\s*)(\+?=)(\s*)',
+             bygroups(Name.Variable, Whitespace, Operator, Whitespace),
+             'assign-rhs'),
+
+            # A top-level module
+            (r'(\w*)(\s*)(\{)',
+             bygroups(Name.Function, Whitespace, Punctuation),
+             'in-rule'),
+
+            # Everything else
+            include('comments'),
+            (r'\s+', Whitespace),  # newlines okay
+        ],
+        'assign-rhs': [
+            include('expr'),
+            (r'\n', Whitespace, '#pop'),
+        ],
+        'in-list': [
+            include('expr'),
+            include('comments'),
+            (r'\s+', Whitespace),  # newlines okay in a list
+            (r',', Punctuation),
+            (r'\]', Punctuation, '#pop'),
+        ],
+        'in-map': [
+            # A map key
+            (r'(\w+)(:)(\s*)', bygroups(Name, Punctuation, Whitespace)),
+
+            include('expr'),
+            include('comments'),
+            (r'\s+', Whitespace),  # newlines okay in a map
+            (r',', Punctuation),
+            (r'\}', Punctuation, '#pop'),
+        ],
+        'in-rule': [
+            # Just re-use map syntax
+            include('in-map'),
+        ],
+        'comments': [
+            (r'//.*', Comment.Single),
+            (r'/(\\\n)?[*](.|\n)*?[*](\\\n)?/', Comment.Multiline),
+        ],
+        'expr': [
+            (r'(true|false)\b', Name.Builtin),
+            (r'0x[0-9a-fA-F]+', Number.Hex),
+            (r'\d+', Number.Integer),
+            (r'".*?"', String),
+            (r'\{', Punctuation, 'in-map'),
+            (r'\[', Punctuation, 'in-list'),
+            (r'\w+', Name),
+        ],
+    }
diff --git a/pygments/lexers/sophia.py b/pygments/lexers/sophia.py
new file mode 100644
index 0000000000000000000000000000000000000000..37fcec5c39bfe4e8f52bf735618971684ab27739
--- /dev/null
+++ b/pygments/lexers/sophia.py
@@ -0,0 +1,102 @@
+"""
+    pygments.lexers.sophia
+    ~~~~~~~~~~~~~~~~~~~~~~
+
+    Lexer for Sophia.
+
+    Derived from pygments/lexers/reason.py.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, include, default, words
+from pygments.token import Comment, Keyword, Name, Number, Operator, \
+    Punctuation, String, Text
+
+__all__ = ['SophiaLexer']
+
+class SophiaLexer(RegexLexer):
+    """
+    A Sophia lexer.
+    """
+
+    name = 'Sophia'
+    aliases = ['sophia']
+    filenames = ['*.aes']
+    mimetypes = []
+    url = 'https://docs.aeternity.com/aesophia'
+    version_added = '2.11'
+
+    keywords = (
+        'contract', 'include', 'let', 'switch', 'type', 'record', 'datatype',
+        'if', 'elif', 'else', 'function', 'stateful', 'payable', 'public',
+        'entrypoint', 'private', 'indexed', 'namespace', 'interface', 'main',
+        'using', 'as', 'for', 'hiding',
+    )
+
+    builtins = ('state', 'put', 'abort', 'require')
+
+    word_operators = ('mod', 'band', 'bor', 'bxor', 'bnot')
+
+    primitive_types = ('int', 'address', 'bool', 'bits', 'bytes', 'string',
+                       'list', 'option', 'char', 'unit', 'map', 'event',
+                       'hash', 'signature', 'oracle', 'oracle_query')
+
+    tokens = {
+        'escape-sequence': [
+            (r'\\[\\"\'ntbr]', String.Escape),
+            (r'\\[0-9]{3}', String.Escape),
+            (r'\\x[0-9a-fA-F]{2}', String.Escape),
+        ],
+        'root': [
+            (r'\s+', Text.Whitespace),
+            (r'(true|false)\b', Keyword.Constant),
+            (r'\b([A-Z][\w\']*)(?=\s*\.)', Name.Class, 'dotted'),
+            (r'\b([A-Z][\w\']*)', Name.Function),
+            (r'//.*?\n', Comment.Single),
+            (r'\/\*(?!/)', Comment.Multiline, 'comment'),
+
+            (r'0[xX][\da-fA-F][\da-fA-F_]*', Number.Hex),
+            (r'#[\da-fA-F][\da-fA-F_]*', Name.Label),
+            (r'\d[\d_]*', Number.Integer),
+
+            (words(keywords, suffix=r'\b'), Keyword),
+            (words(builtins, suffix=r'\b'), Name.Builtin),
+            (words(word_operators, prefix=r'\b', suffix=r'\b'), Operator.Word),
+            (words(primitive_types, prefix=r'\b', suffix=r'\b'), Keyword.Type),
+
+            (r'[=!<>+\\*/:&|?~@^-]', Operator.Word),
+            (r'[.;:{}(),\[\]]', Punctuation),
+
+            (r"(ak_|ok_|oq_|ct_)[\w']*", Name.Label),
+            (r"[^\W\d][\w']*", Name),
+
+            (r"'(?:(\\[\\\"'ntbr ])|(\\[0-9]{3})|(\\x[0-9a-fA-F]{2}))'",
+             String.Char),
+            (r"'.'", String.Char),
+            (r"'[a-z][\w]*", Name.Variable),
+
+            (r'"', String.Double, 'string')
+        ],
+        'comment': [
+            (r'[^/*]+', Comment.Multiline),
+            (r'\/\*', Comment.Multiline, '#push'),
+            (r'\*\/', Comment.Multiline, '#pop'),
+            (r'\*', Comment.Multiline),
+        ],
+        'string': [
+            (r'[^\\"]+', String.Double),
+            include('escape-sequence'),
+            (r'\\\n', String.Double),
+            (r'"', String.Double, '#pop'),
+        ],
+        'dotted': [
+            (r'\s+', Text),
+            (r'\.', Punctuation),
+            (r'[A-Z][\w\']*(?=\s*\.)', Name.Function),
+            (r'[A-Z][\w\']*', Name.Function, '#pop'),
+            (r'[a-z_][\w\']*', Name, '#pop'),
+            default('#pop'),
+        ],
+    }
diff --git a/pygments/lexers/special.py b/pygments/lexers/special.py
new file mode 100644
index 0000000000000000000000000000000000000000..524946fc310a6acef08db483ac7f12f09a5e8706
--- /dev/null
+++ b/pygments/lexers/special.py
@@ -0,0 +1,122 @@
+"""
+    pygments.lexers.special
+    ~~~~~~~~~~~~~~~~~~~~~~~
+
+    Special lexers.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+import ast
+
+from pygments.lexer import Lexer, line_re
+from pygments.token import Token, Error, Text, Generic
+from pygments.util import get_choice_opt
+
+
+__all__ = ['TextLexer', 'OutputLexer', 'RawTokenLexer']
+
+
+class TextLexer(Lexer):
+    """
+    "Null" lexer, doesn't highlight anything.
+    """
+    name = 'Text only'
+    aliases = ['text']
+    filenames = ['*.txt']
+    mimetypes = ['text/plain']
+    url = ""
+    version_added = ''
+
+    priority = 0.01
+
+    def get_tokens_unprocessed(self, text):
+        yield 0, Text, text
+
+    def analyse_text(text):
+        return TextLexer.priority
+
+
+class OutputLexer(Lexer):
+    """
+    Simple lexer that highlights everything as ``Token.Generic.Output``.
+    """
+    name = 'Text output'
+    aliases = ['output']
+    url = ""
+    version_added = '2.10'
+    _example = "output/output"
+
+    def get_tokens_unprocessed(self, text):
+        yield 0, Generic.Output, text
+
+
+_ttype_cache = {}
+
+
+class RawTokenLexer(Lexer):
+    """
+    Recreate a token stream formatted with the `RawTokenFormatter`.
+
+    Additional options accepted:
+
+    `compress`
+        If set to ``"gz"`` or ``"bz2"``, decompress the token stream with
+        the given compression algorithm before lexing (default: ``""``).
+    """
+    name = 'Raw token data'
+    aliases = []
+    filenames = []
+    mimetypes = ['application/x-pygments-tokens']
+    url = 'https://pygments.org/docs/formatters/#RawTokenFormatter'
+    version_added = ''
+
+    def __init__(self, **options):
+        self.compress = get_choice_opt(options, 'compress',
+                                       ['', 'none', 'gz', 'bz2'], '')
+        Lexer.__init__(self, **options)
+
+    def get_tokens(self, text):
+        if self.compress:
+            if isinstance(text, str):
+                text = text.encode('latin1')
+            try:
+                if self.compress == 'gz':
+                    import gzip
+                    text = gzip.decompress(text)
+                elif self.compress == 'bz2':
+                    import bz2
+                    text = bz2.decompress(text)
+            except OSError:
+                yield Error, text.decode('latin1')
+        if isinstance(text, bytes):
+            text = text.decode('latin1')
+
+        # do not call Lexer.get_tokens() because stripping is not optional.
+        text = text.strip('\n') + '\n'
+        for i, t, v in self.get_tokens_unprocessed(text):
+            yield t, v
+
+    def get_tokens_unprocessed(self, text):
+        length = 0
+        for match in line_re.finditer(text):
+            try:
+                ttypestr, val = match.group().rstrip().split('\t', 1)
+                ttype = _ttype_cache.get(ttypestr)
+                if not ttype:
+                    ttype = Token
+                    ttypes = ttypestr.split('.')[1:]
+                    for ttype_ in ttypes:
+                        if not ttype_ or not ttype_[0].isupper():
+                            raise ValueError('malformed token name')
+                        ttype = getattr(ttype, ttype_)
+                    _ttype_cache[ttypestr] = ttype
+                val = ast.literal_eval(val)
+                if not isinstance(val, str):
+                    raise ValueError('expected str')
+            except (SyntaxError, ValueError):
+                val = match.group()
+                ttype = Error
+            yield length, ttype, val
+            length += len(val)
diff --git a/pygments/lexers/spice.py b/pygments/lexers/spice.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d2b1a1a81a47b9ac830b204df3ce015b8f66985
--- /dev/null
+++ b/pygments/lexers/spice.py
@@ -0,0 +1,70 @@
+"""
+    pygments.lexers.spice
+    ~~~~~~~~~~~~~~~~~~~~~
+
+    Lexers for the Spice programming language.
+
+    :copyright: Copyright 2006-2025 by the Pygments team, see AUTHORS.
+    :license: BSD, see LICENSE for details.
+"""
+
+from pygments.lexer import RegexLexer, bygroups, words
+from pygments.token import Text, Comment, Operator, Keyword, Name, String, \
+    Number, Punctuation, Whitespace
+
+__all__ = ['SpiceLexer']
+
+
+class SpiceLexer(RegexLexer):
+    """
+    For Spice source.
+    """
+    name = 'Spice'
+    url = 'https://www.spicelang.com'
+    filenames = ['*.spice']
+    aliases = ['spice', 'spicelang']
+    mimetypes = ['text/x-spice']
+    version_added = '2.11'
+
+    tokens = {
+        'root': [
+            (r'\n', Whitespace),
+            (r'\s+', Whitespace),
+            (r'\\\n', Text),
+            # comments
+            (r'//(.*?)\n', Comment.Single),
+            (r'/(\\\n)?[*]{2}(.|\n)*?[*](\\\n)?/', String.Doc),
+            (r'/(\\\n)?[*](.|\n)*?[*](\\\n)?/', Comment.Multiline),
+            # keywords
+            (r'(import|as)\b', Keyword.Namespace),
+            (r'(f|p|type|struct|interface|enum|alias|operator)\b', Keyword.Declaration),
+            (words(('if', 'else', 'switch', 'case', 'default', 'for', 'foreach', 'do',
+                    'while', 'break', 'continue', 'fallthrough', 'return', 'assert',
+                    'unsafe', 'ext'), suffix=r'\b'), Keyword),
+            (words(('const', 'signed', 'unsigned', 'inline', 'public', 'heap', 'compose'),
+                   suffix=r'\b'), Keyword.Pseudo),
+            (words(('new', 'yield', 'stash', 'pick', 'sync', 'class'), suffix=r'\b'),
+                   Keyword.Reserved),
+            (r'(true|false|nil)\b', Keyword.Constant),
+            (words(('double', 'int', 'short', 'long', 'byte', 'char', 'string',
+                    'bool', 'dyn'), suffix=r'\b'), Keyword.Type),
+            (words(('printf', 'sizeof', 'alignof', 'len', 'panic'), suffix=r'\b(\()'),
+             bygroups(Name.Builtin, Punctuation)),
+            # numeric literals
+            (r'[-]?[0-9]*[.][0-9]+([eE][+-]?[0-9]+)?', Number.Double),
+            (r'0[bB][01]+[slu]?', Number.Bin),
+            (r'0[oO][0-7]+[slu]?', Number.Oct),
+            (r'0[xXhH][0-9a-fA-F]+[slu]?', Number.Hex),
+            (r'(0[dD])?[0-9]+[slu]?', Number.Integer),
+            # string literal
+            (r'"(\\\\|\\[^\\]|[^"\\])*"', String),
+            # char literal
+            (r'\'(\\\\|\\[^\\]|[^\'\\])\'', String.Char),
+            # tokens
+            (r'<<=|>>=|<<|>>|<=|>=|\+=|-=|\*=|/=|\%=|\|=|&=|\^=|&&|\|\||&|\||'
+             r'\+\+|--|\%|\^|\~|==|!=|->|::|[.]{3}|#!|#|[+\-*/&]', Operator),
+            (r'[|<>=!()\[\]{}.,;:\?]', Punctuation),
+            # identifiers
+            (r'[^\W\d]\w*', Name.Other),
+        ]
+    }
diff --git a/pygments/lexers/sql.py b/pygments/lexers/sql.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3e6f17f39491b43cc621ca6981ff33ca90c380f
--- /dev/null
+++ b/pygments/lexers/sql.py
@@ -0,0 +1,1109 @@
+"""
+    pygments.lexers.sql
+    ~~~~~~~~~~~~~~~~~~~
+
+    Lexers for various SQL dialects and related interactive sessions.
+
+    Postgres specific lexers:
+
+    `PostgresLexer`
+        A SQL lexer for the PostgreSQL dialect. Differences w.r.t. the SQL
+        lexer are:
+
+        - keywords and data types list parsed from the PG docs (run the
+          `_postgres_builtins` module to update them);
+        - Content of $-strings parsed using a specific lexer, e.g. the content
+          of a PL/Python function is parsed using the Python lexer;
+        - parse PG specific constructs: E-strings, $-strings, U&-strings,
+          different operators and punctuation.
+
+    `PlPgsqlLexer`
+        A lexer for the PL/pgSQL language. Adds a few specific construct on
+        top of the PG SQL lexer (such as <